import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from util import datapoints
import transforms as _transforms
from transforms import _functional_tensor as _FT
from transforms.v2 import functional as F, Transform
from transforms import AutoAugmentPolicy, InterpolationMode
from transforms.v2.functional._geometry import _check_interpolation
from transforms.v2.functional._meta import get_spatial_size
from ._utils import _setup_fill_arg
from .utils import check_type, is_simple_tensor
class _AutoAugmentBase(Transform):
def __init__(
self,
*,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
) -> None:
super().__init__()
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _flatten_and_extract_image_or_video(
self,
inputs: Any,
unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask),
) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints._ImageType, datapoints._VideoType]]:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
needs_transform_list = self._needs_transform_list(flat_inputs)
image_or_videos = []
for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
if needs_transform and check_type(
inpt,
(
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
),
):
image_or_videos.append((idx, inpt))
elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
if not image_or_videos:
raise TypeError("Found no image in the sample.")
if len(image_or_videos) > 1:
raise TypeError(
f"Auto augment transformations are only properly defined for a single image or video, "
f"but found {len(image_or_videos)}."
)
idx, image_or_video = image_or_videos[0]
return (flat_inputs, spec, idx), image_or_video
def _unflatten_and_insert_image_or_video(
self,
flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
image_or_video: Union[datapoints._ImageType, datapoints._VideoType],
) -> Any:
flat_inputs, spec, idx = flat_inputs_with_spec
flat_inputs[idx] = image_or_video
return tree_unflatten(flat_inputs, spec)
def _apply_image_or_video_transform(
self,
image: Union[datapoints._ImageType, datapoints._VideoType],
transform_id: str,
magnitude: float,
interpolation: Union[InterpolationMode, int],
fill: Dict[Type, datapoints._FillTypeJIT],
) -> Union[datapoints._ImageType, datapoints._VideoType]:
fill_ = fill[type(image)]
if transform_id == "Identity":
return image
elif transform_id == "ShearX":
# magnitude should be arctan(magnitude)
# official autoaug: (1, level, 0, 0, 1, 0)
# https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
# compared to
# torchvision: (1, tan(level), 0, 0, 1, 0)
# https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
return F.affine(
image,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(math.atan(magnitude)), 0.0],
interpolation=interpolation,
fill=fill_,
center=[0, 0],
)
elif transform_id == "ShearY":
# magnitude should be arctan(magnitude)
# See above
return F.affine(
image,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(math.atan(magnitude))],
interpolation=interpolation,
fill=fill_,
center=[0, 0],
)
elif transform_id == "TranslateX":
return F.affine(
image,
angle=0.0,
translate=[int(magnitude), 0],
scale=1.0,
interpolation=interpolation,
shear=[0.0, 0.0],
fill=fill_,
)
elif transform_id == "TranslateY":
return F.affine(
image,
angle=0.0,
translate=[0, int(magnitude)],
scale=1.0,
interpolation=interpolation,
shear=[0.0, 0.0],
fill=fill_,
)
elif transform_id == "Rotate":
return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
elif transform_id == "Brightness":
return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
elif transform_id == "Color":
return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
elif transform_id == "Contrast":
return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
elif transform_id == "Sharpness":
return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
elif transform_id == "Posterize":
return F.posterize(image, bits=int(magnitude))
elif transform_id == "Solarize":
bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
return F.solarize(image, threshold=bound * magnitude)
elif transform_id == "AutoContrast":
return F.autocontrast(image)
elif transform_id == "Equalize":
return F.equalize(image)
elif transform_id == "Invert":
return F.invert(image)
else:
raise ValueError(f"No transform available for {transform_id}")
class AutoAugment(_AutoAugmentBase):
r"""[BETA] AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
.. v2betastatus:: AutoAugment transform
This transformation works on images and videos only.
If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
policy (AutoAugmentPolicy, optional): Desired policy enum defined by
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands.
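Example:
A minimal usage sketch (illustrative only; ``img`` is assumed to be a ``torch.uint8`` image tensor):
>>> img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
>>> augmenter = AutoAugment(policy=AutoAugmentPolicy.IMAGENET)
>>> out = augmenter(img)  # one randomly chosen sub-policy is applied; same shape and dtype as ``img``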
"""
_v1_transform_cls = _transforms.AutoAugment
_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
True,
),
"TranslateY": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
True,
),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
"Invert": (lambda num_bins, height, width: None, False),
}
def __init__(
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
self._policies = self._get_policies(policy)
def _get_policies(
self, policy: AutoAugmentPolicy
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
if policy == AutoAugmentPolicy.IMAGENET:
return [
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
]
elif policy == AutoAugmentPolicy.CIFAR10:
return [
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
]
elif policy == AutoAugmentPolicy.SVHN:
return [
(("ShearX", 0.9, 4), ("Invert", 0.2, None)),
(("ShearY", 0.9, 8), ("Invert", 0.7, None)),
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
(("ShearY", 0.9, 8), ("Invert", 0.4, None)),
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
(("ShearY", 0.8, 8), ("Invert", 0.7, None)),
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
(("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
(("Invert", 0.6, None), ("Rotate", 0.8, 4)),
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
(("ShearX", 0.1, 6), ("Invert", 0.6, None)),
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
(("ShearY", 0.8, 4), ("Invert", 0.8, None)),
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
]
else:
raise ValueError(f"The provided policy {policy} is not recognized.")
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_spatial_size(image_or_video)
policy = self._policies[int(torch.randint(len(self._policies), ()))]
for transform_id, probability, magnitude_idx in policy:
if not torch.rand(()) <= probability:
continue
magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
magnitudes = magnitudes_fn(10, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
class RandAugment(_AutoAugmentBase):
r"""[BETA] RandAugment data augmentation method based on
`"RandAugment: Practical automated data augmentation with a reduced search space"
<https://arxiv.org/abs/1909.13719>`_.
.. v2betastatus:: RandAugment transform
This transformation works on images and videos only.
If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
num_ops (int, optional): Number of augmentation transformations to apply sequentially.
magnitude (int, optional): Magnitude for all the transformations.
num_magnitude_bins (int, optional): The number of different magnitude values.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands.
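Example:
A minimal usage sketch (illustrative only); note that ``magnitude`` indexes into the
``num_magnitude_bins`` bins, so it must be smaller than ``num_magnitude_bins``:
>>> img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
>>> augmenter = RandAugment(num_ops=2, magnitude=9, num_magnitude_bins=31)
>>> out = augmenter(img)  # ``num_ops`` randomly chosen operations are applied sequentially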
"""
_v1_transform_cls = _transforms.RandAugment
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
True,
),
"TranslateY": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
True,
),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
def __init__(
self,
num_ops: int = 2,
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_spatial_size(image_or_video)
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[self.magnitude])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
class TrivialAugmentWide(_AutoAugmentBase):
r"""[BETA] Dataset-independent data-augmentation with TrivialAugment Wide, as described in
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
.. v2betastatus:: TrivialAugmentWide transform
This transformation works on images and videos only.
If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
num_magnitude_bins (int, optional): The number of different magnitude values.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands.
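Example:
A minimal usage sketch (illustrative only); a single operation and a random magnitude bin are
sampled on every call:
>>> img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
>>> out = TrivialAugmentWide(num_magnitude_bins=31)(img)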
"""
_v1_transform_cls = _transforms.TrivialAugmentWide
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
def __init__(
self,
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_spatial_size(image_or_video)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
class AugMix(_AutoAugmentBase):
r"""[BETA] AugMix data augmentation method based on
`"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
.. v2betastatus:: AugMix transform
This transformation works on images and videos only.
If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
severity (int, optional): The severity of base augmentation operators. Default is ``3``.
mixture_width (int, optional): The number of augmentation chains. Default is ``3``.
chain_depth (int, optional): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
Default is ``-1``.
alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands.
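Example:
A minimal usage sketch (illustrative only); the output is a convex combination of the original
input and ``mixture_width`` augmented chains, cast back to the input dtype:
>>> img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
>>> augmenter = AugMix(severity=3, mixture_width=3, chain_depth=-1, alpha=1.0)
>>> out = augmenter(img)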
"""
_v1_transform_cls = _transforms.AugMix
_PARTIAL_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Posterize": (
lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
_AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
**_PARTIAL_AUGMENTATION_SPACE,
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
}
def __init__(
self,
severity: int = 3,
mixture_width: int = 3,
chain_depth: int = -1,
alpha: float = 1.0,
all_ops: bool = True,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
if not (1 <= severity <= self._PARAMETER_MAX):
raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
self.severity = severity
self.mixture_width = mixture_width
self.chain_depth = chain_depth
self.alpha = alpha
self.all_ops = all_ops
def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
# Must be on a separate method so that we can overwrite it in tests.
return torch._sample_dirichlet(params)
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_spatial_size(orig_image_or_video)
if isinstance(orig_image_or_video, torch.Tensor):
image_or_video = orig_image_or_video
else: # isinstance(inpt, PIL.Image.Image):
image_or_video = F.pil_to_tensor(orig_image_or_video)
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
orig_dims = list(image_or_video.shape)
expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4
batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
# Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
# Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
# augmented image or video.
m = self._sample_dirichlet(
torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
)
# Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
combined_weights = self._sample_dirichlet(
torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
) * m[:, 1].reshape([batch_dims[0], -1])
mix = m[:, 0].reshape(batch_dims) * batch
for i in range(self.mixture_width):
aug = batch
depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
for _ in range(depth):
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
aug = self._apply_image_or_video_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
mix = orig_image_or_video.wrap_like(orig_image_or_video, mix) # type: ignore[arg-type]
elif isinstance(orig_image_or_video, PIL.Image.Image):
mix = F.to_image_pil(mix)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
import collections.abc
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import PIL.Image
import torch
from util import datapoints
import transforms as _transforms
from transforms.v2 import functional as F, Transform
from ._transform import _RandomApplyTransform
from .utils import is_simple_tensor, query_chw
class Grayscale(Transform):
"""[BETA] Convert images or videos to grayscale.
.. v2betastatus:: Grayscale transform
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 3 or 1, H, W] shape, where ... means an arbitrary number of leading dimensions
Args:
num_output_channels (int): (1 or 3) number of channels desired for output image
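Example:
A minimal usage sketch (illustrative only); with ``num_output_channels=3`` the grayscale values
are replicated across three channels:
>>> img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
>>> out = Grayscale(num_output_channels=3)(img)  # shape stays (3, 224, 224)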
"""
_v1_transform_cls = _transforms.Grayscale
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, num_output_channels: int = 1):
super().__init__()
self.num_output_channels = num_output_channels
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
class RandomGrayscale(_RandomApplyTransform):
"""[BETA] Randomly convert image or videos to grayscale with a probability of p (default 0.1).
.. v2betastatus:: RandomGrayscale transform
If the input is a :class:`torch.Tensor`, it is expected to have [..., 3 or 1, H, W] shape,
where ... means an arbitrary number of leading dimensions
The output has the same number of channels as the input.
Args:
p (float): probability that image should be converted to grayscale.
"""
_v1_transform_cls = _transforms.RandomGrayscale
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, p: float = 0.1) -> None:
super().__init__(p=p)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_input_channels, *_ = query_chw(flat_inputs)
return dict(num_input_channels=num_input_channels)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
class ColorJitter(Transform):
"""[BETA] Randomly change the brightness, contrast, saturation and hue of an image or video.
.. v2betastatus:: ColorJitter transform
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
Args:
brightness (float or tuple of float (min, max)): How much to jitter brightness.
brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
or the given [min, max]. Should be non-negative numbers.
contrast (float or tuple of float (min, max)): How much to jitter contrast.
contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
or the given [min, max]. Should be non-negative numbers.
saturation (float or tuple of float (min, max)): How much to jitter saturation.
saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
or the given [min, max]. Should be non-negative numbers.
hue (float or tuple of float (min, max)): How much to jitter hue.
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
thus it does not work if you normalize your image to an interval with negative values,
or use an interpolation that generates negative values before using this function.
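Example:
A minimal usage sketch (illustrative only); a single number ``x`` means the factor is sampled from
``[max(0, 1 - x), 1 + x]`` (``[-x, x]`` for hue), while a pair is used as the range directly:
>>> img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
>>> jitter = ColorJitter(brightness=0.5, contrast=(0.8, 1.2), saturation=0.3, hue=0.1)
>>> out = jitter(img)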
"""
_v1_transform_cls = _transforms.ColorJitter
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()}
def __init__(
self,
brightness: Optional[Union[float, Sequence[float]]] = None,
contrast: Optional[Union[float, Sequence[float]]] = None,
saturation: Optional[Union[float, Sequence[float]]] = None,
hue: Optional[Union[float, Sequence[float]]] = None,
) -> None:
super().__init__()
self.brightness = self._check_input(brightness, "brightness")
self.contrast = self._check_input(contrast, "contrast")
self.saturation = self._check_input(saturation, "saturation")
self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
def _check_input(
self,
value: Optional[Union[float, Sequence[float]]],
name: str,
center: float = 1.0,
bound: Tuple[float, float] = (0, float("inf")),
clip_first_on_zero: bool = True,
) -> Optional[Tuple[float, float]]:
if value is None:
return None
if isinstance(value, (int, float)):
if value < 0:
raise ValueError(f"If {name} is a single number, it must be non negative.")
value = [center - value, center + value]
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
value = [float(v) for v in value]
else:
raise TypeError(f"{name}={value} should be a single number or a sequence with length 2.")
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}, but got {value}.")
return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))
@staticmethod
def _generate_value(left: float, right: float) -> float:
return torch.empty(1).uniform_(left, right).item()
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
fn_idx = torch.randperm(4)
b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1])
s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1])
h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1])
return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = inpt
brightness_factor = params["brightness_factor"]
contrast_factor = params["contrast_factor"]
saturation_factor = params["saturation_factor"]
hue_factor = params["hue_factor"]
for fn_id in params["fn_idx"]:
if fn_id == 0 and brightness_factor is not None:
output = F.adjust_brightness(output, brightness_factor=brightness_factor)
elif fn_id == 1 and contrast_factor is not None:
output = F.adjust_contrast(output, contrast_factor=contrast_factor)
elif fn_id == 2 and saturation_factor is not None:
output = F.adjust_saturation(output, saturation_factor=saturation_factor)
elif fn_id == 3 and hue_factor is not None:
output = F.adjust_hue(output, hue_factor=hue_factor)
return output
# TODO: This class seems to be untested
class RandomPhotometricDistort(Transform):
"""[BETA] Randomly distorts the image or video as used in `SSD: Single Shot
MultiBox Detector <https://arxiv.org/abs/1512.02325>`_.
.. v2betastatus:: RandomPhotometricDistort transform
This transform relies on :class:`~torchvision.transforms.v2.ColorJitter`
under the hood to adjust the contrast, saturation, hue, brightness, and also
randomly permutes channels.
Args:
brightness (tuple of float (min, max), optional): How much to jitter brightness.
brightness_factor is chosen uniformly from [min, max]. Should be non-negative numbers.
contrast (tuple of float (min, max), optional): How much to jitter contrast.
contrast_factor is chosen uniformly from [min, max]. Should be non-negative numbers.
saturation (tuple of float (min, max), optional): How much to jitter saturation.
saturation_factor is chosen uniformly from [min, max]. Should be non-negative numbers.
hue (tuple of float (min, max), optional): How much to jitter hue.
hue_factor is chosen uniformly from [min, max]. Should have -0.5 <= min <= max <= 0.5.
To jitter hue, the pixel values of the input image have to be non-negative for conversion to HSV space;
thus it does not work if you normalize your image to an interval with negative values,
or use an interpolation that generates negative values before using this function.
p (float, optional): probability that each distortion operation (contrast, saturation, ...) is applied.
Default is 0.5.
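Example:
A minimal usage sketch with the SSD-style defaults (illustrative only); each distortion and the
channel permutation are applied independently with probability ``p``:
>>> img = torch.randint(0, 256, (3, 300, 300), dtype=torch.uint8)
>>> out = RandomPhotometricDistort(p=0.5)(img)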
"""
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(
self,
brightness: Tuple[float, float] = (0.875, 1.125),
contrast: Tuple[float, float] = (0.5, 1.5),
saturation: Tuple[float, float] = (0.5, 1.5),
hue: Tuple[float, float] = (-0.05, 0.05),
p: float = 0.5,
):
super().__init__()
self.brightness = brightness
self.contrast = contrast
self.hue = hue
self.saturation = saturation
self.p = p
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_channels, *_ = query_chw(flat_inputs)
params: Dict[str, Any] = {
key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None
for key, range in [
("brightness_factor", self.brightness),
("contrast_factor", self.contrast),
("saturation_factor", self.saturation),
("hue_factor", self.hue),
]
}
params["contrast_before"] = bool(torch.rand(()) < 0.5)
params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
return params
def _permute_channels(
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor
) -> Union[datapoints._ImageType, datapoints._VideoType]:
orig_inpt = inpt
if isinstance(orig_inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt)
# TODO: Find a better fix than as_subclass???
output = inpt[..., permutation, :, :].as_subclass(type(inpt))
if isinstance(orig_inpt, PIL.Image.Image):
output = F.to_image_pil(output)
return output
def _transform(
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
) -> Union[datapoints._ImageType, datapoints._VideoType]:
if params["brightness_factor"] is not None:
inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
if params["contrast_factor"] is not None and params["contrast_before"]:
inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
if params["saturation_factor"] is not None:
inpt = F.adjust_saturation(inpt, saturation_factor=params["saturation_factor"])
if params["hue_factor"] is not None:
inpt = F.adjust_hue(inpt, hue_factor=params["hue_factor"])
if params["contrast_factor"] is not None and not params["contrast_before"]:
inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
if params["channel_permutation"] is not None:
inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
return inpt
class RandomEqualize(_RandomApplyTransform):
"""[BETA] Equalize the histogram of the given image or video with a given probability.
.. v2betastatus:: RandomEqualize transform
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
Args:
p (float): probability of the image being equalized. Default value is 0.5
"""
_v1_transform_cls = _transforms.RandomEqualize
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.equalize(inpt)
class RandomInvert(_RandomApplyTransform):
"""[BETA] Inverts the colors of the given image or video with a given probability.
.. v2betastatus:: RandomInvert transform
If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
p (float): probability of the image being color inverted. Default value is 0.5
"""
_v1_transform_cls = _transforms.RandomInvert
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.invert(inpt)
class RandomPosterize(_RandomApplyTransform):
"""[BETA] Posterize the image or video with a given probability by reducing the
number of bits for each color channel.
.. v2betastatus:: RandomPosterize transform
If the input is a :class:`torch.Tensor`, it should be of type torch.uint8,
and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
bits (int): number of bits to keep for each channel (0-8)
p (float): probability of the image being posterized. Default value is 0.5
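Example:
A minimal usage sketch (illustrative only); fewer kept bits means stronger posterization:
>>> img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
>>> out = RandomPosterize(bits=4, p=0.5)(img)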
"""
_v1_transform_cls = _transforms.RandomPosterize
def __init__(self, bits: int, p: float = 0.5) -> None:
super().__init__(p=p)
self.bits = bits
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.posterize(inpt, bits=self.bits)
class RandomSolarize(_RandomApplyTransform):
"""[BETA] Solarize the image or video with a given probability by inverting all pixel
values above a threshold.
.. v2betastatus:: RandomSolarize transform
If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
threshold (float): all pixels equal to or above this value are inverted.
p (float): probability of the image being solarized. Default value is 0.5
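Example:
A minimal usage sketch (illustrative only); for a ``torch.uint8`` input the threshold lives on the
``[0, 255]`` scale, for float images on ``[0, 1]``:
>>> img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
>>> out = RandomSolarize(threshold=192.0, p=0.5)(img)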
"""
_v1_transform_cls = _transforms.RandomSolarize
def __init__(self, threshold: float, p: float = 0.5) -> None:
super().__init__(p=p)
self.threshold = threshold
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.solarize(inpt, threshold=self.threshold)
class RandomAutocontrast(_RandomApplyTransform):
"""[BETA] Autocontrast the pixels of the given image or video with a given probability.
.. v2betastatus:: RandomAutocontrast transform
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
p (float): probability of the image being autocontrasted. Default value is 0.5
"""
_v1_transform_cls = _transforms.RandomAutocontrast
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.autocontrast(inpt)
class RandomAdjustSharpness(_RandomApplyTransform):
"""[BETA] Adjust the sharpness of the image or video with a given probability.
.. v2betastatus:: RandomAdjustSharpness transform
If the input is a :class:`torch.Tensor`,
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
sharpness_factor (float): How much to adjust the sharpness. Can be
any non-negative number. 0 gives a blurred image, 1 gives the
original image while 2 increases the sharpness by a factor of 2.
p (float): probability of the image being sharpened. Default value is 0.5
"""
_v1_transform_cls = _transforms.RandomAdjustSharpness
def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
super().__init__(p=p)
self.sharpness_factor = sharpness_factor
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor)
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import torch
from torch import nn
import transforms as _transforms
from transforms.v2 import Transform
class Compose(Transform):
"""[BETA] Composes several transforms together.
.. v2betastatus:: Compose transform
This transform does not support torchscript.
Please, see the note below.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. ones that work with ``torch.Tensor`` and do not require
``lambda`` functions or ``PIL.Image``.
"""
def __init__(self, transforms: Sequence[Callable]) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
self.transforms = transforms
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
for transform in self.transforms:
sample = transform(sample)
return sample
def extra_repr(self) -> str:
format_string = []
for t in self.transforms:
format_string.append(f" {t}")
return "\n".join(format_string)
class RandomApply(Transform):
"""[BETA] Apply randomly a list of transformations with a given probability.
.. v2betastatus:: RandomApply transform
.. note::
In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
transforms as shown below:
>>> transforms = transforms.RandomApply(torch.nn.ModuleList([
>>> transforms.ColorJitter(),
>>> ]), p=0.3)
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. ones that work with ``torch.Tensor`` and do not require
``lambda`` functions or ``PIL.Image``.
Args:
transforms (sequence or torch.nn.Module): list of transformations
p (float): probability of applying the list of transforms
"""
_v1_transform_cls = _transforms.RandomApply
def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
super().__init__()
if not isinstance(transforms, (Sequence, nn.ModuleList)):
raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
self.transforms = transforms
if not (0.0 <= p <= 1.0):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
self.p = p
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return {"transforms": self.transforms, "p": self.p}
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) >= self.p:
return sample
for transform in self.transforms:
sample = transform(sample)
return sample
def extra_repr(self) -> str:
format_string = []
for t in self.transforms:
format_string.append(f" {t}")
return "\n".join(format_string)
class RandomChoice(Transform):
"""[BETA] Apply single transformation randomly picked from a list.
.. v2betastatus:: RandomChoice transform
This transform does not support torchscript.
Args:
transforms (sequence or torch.nn.Module): list of transformations
p (list of floats or None, optional): probability of each transform being picked.
If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
(default), all transforms have the same probability.
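Example:
A minimal usage sketch (illustrative only; ``RandomHorizontalFlip`` and ``RandomInvert`` are other
transforms from this library, and ``img`` stands for any supported image input); exactly one of the
listed transforms is picked per call according to the normalized weights:
>>> transform = RandomChoice([RandomHorizontalFlip(), RandomInvert()], p=[0.7, 0.3])
>>> out = transform(img)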
"""
def __init__(
self,
transforms: Sequence[Callable],
p: Optional[List[float]] = None,
) -> None:
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
if p is None:
p = [1] * len(transforms)
elif len(p) != len(transforms):
raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")
super().__init__()
self.transforms = transforms
total = sum(p)
self.p = [prob / total for prob in p]
def forward(self, *inputs: Any) -> Any:
idx = int(torch.multinomial(torch.tensor(self.p), 1))
transform = self.transforms[idx]
return transform(*inputs)
class RandomOrder(Transform):
"""[BETA] Apply a list of transformations in a random order.
.. v2betastatus:: RandomOrder transform
This transform does not support torchscript.
Args:
transforms (sequence or torch.nn.Module): list of transformations
"""
def __init__(self, transforms: Sequence[Callable]) -> None:
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
super().__init__()
self.transforms = transforms
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
for idx in torch.randperm(len(self.transforms)):
transform = self.transforms[idx]
sample = transform(sample)
return sample
import warnings
from typing import Any, Dict, Union
import numpy as np
import PIL.Image
import torch
from transforms import functional as _F
from transforms.v2 import Transform
class ToTensor(Transform):
"""[BETA] Convert a PIL Image or ndarray to tensor and scale the values accordingly.
.. v2betastatus:: ToTensor transform
.. warning::
:class:`v2.ToTensor` is deprecated and will be removed in a future release.
Please use instead ``transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])``.
This transform does not support torchscript.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
or if the numpy.ndarray has dtype = np.uint8
In the other cases, tensors are returned without scaling.
.. note::
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
transforming target image masks. See the `references`_ for implementing the transforms for image masks.
.. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
"""
_transformed_types = (PIL.Image.Image, np.ndarray)
def __init__(self) -> None:
warnings.warn(
"The transform `ToTensor()` is deprecated and will be removed in a future release. "
"Instead, please use `transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])`."
)
super().__init__()
def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
return _F.to_tensor(inpt)
import math
import numbers
import warnings
from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
import PIL.Image
import torch
from util import datapoints
import transforms as _transforms
from torchvision.ops.boxes import box_iou
from transforms.functional import _get_perspective_coeffs
from transforms.v2 import functional as F, Transform
from transforms import InterpolationMode
from transforms.v2.functional._geometry import _check_interpolation
from ._transform import _RandomApplyTransform
from ._utils import (
_check_padding_arg,
_check_padding_mode_arg,
_check_sequence_input,
_setup_angle,
_setup_fill_arg,
_setup_float_or_seq,
_setup_size,
)
from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_spatial_size
class RandomHorizontalFlip(_RandomApplyTransform):
"""[BETA] Horizontally flip the input with a given probability.
.. v2betastatus:: RandomHorizontalFlip transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Args:
p (float, optional): probability of the input being flipped. Default value is 0.5
"""
_v1_transform_cls = _transforms.RandomHorizontalFlip
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.horizontal_flip(inpt)
class RandomVerticalFlip(_RandomApplyTransform):
"""[BETA] Vertically flip the input with a given probability.
.. v2betastatus:: RandomVerticalFlip transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Args:
p (float, optional): probability of the input being flipped. Default value is 0.5
"""
_v1_transform_cls = _transforms.RandomVerticalFlip
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.vertical_flip(inpt)
class Resize(Transform):
"""[BETA] Resize the input to the given size.
.. v2betastatus:: Resize transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
.. warning::
The output image might be different depending on its type: when downsampling, the interpolation of PIL images
and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
closer.
Args:
size (sequence or int): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e., if height > width, then image will be rescaled to
(size * height / width, size).
.. note::
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then
the image is resized again so that the longer edge is equal to
``max_size``. As a result, ``size`` might be overruled, i.e. the
smaller edge may be shorter than ``size``. This is only supported
if ``size`` is an int (or a sequence of length 1 in torchscript
mode).
antialias (bool, optional): Whether to apply antialiasing.
It only affects **tensors** with bilinear or bicubic modes and it is
ignored otherwise: on PIL images, antialiasing is always applied on
bilinear or bicubic modes; on other modes (for PIL images and
tensors), antialiasing makes no sense and this parameter is ignored.
Possible values are:
- ``True``: will apply antialiasing for bilinear or bicubic modes.
Other modes aren't affected. This is probably what you want to use.
- ``False``: will not apply antialiasing for tensors on any mode. PIL
images are still antialiased on bilinear or bicubic modes, because
PIL doesn't support disabling antialiasing.
- ``None``: equivalent to ``False`` for tensors and ``True`` for
PIL images. This value exists for legacy reasons and you probably
don't want to use it unless you really know what you are doing.
The current default is ``None`` **but will change to** ``True`` **in
v0.17** for the PIL and Tensor backends to be consistent.
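Example:
A minimal usage sketch (illustrative only); with an ``int`` size the smaller edge is matched to
``size`` while ``max_size`` caps the longer edge:
>>> img = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
>>> out = Resize(size=256, max_size=512, antialias=True)(img)  # smaller edge becomes 256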
"""
_v1_transform_cls = _transforms.Resize
def __init__(
self,
size: Union[int, Sequence[int]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> None:
super().__init__()
if isinstance(size, int):
size = [size]
elif isinstance(size, (list, tuple)) and len(size) in {1, 2}:
size = list(size)
else:
raise ValueError(
f"size can either be an integer or a list or tuple of one or two integers, " f"but got {size} instead."
)
self.size = size
self.interpolation = _check_interpolation(interpolation)
self.max_size = max_size
self.antialias = antialias
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(
inpt,
self.size,
interpolation=self.interpolation,
max_size=self.max_size,
antialias=self.antialias,
)
class CenterCrop(Transform):
"""[BETA] Crop the input at the center.
.. v2betastatus:: CenterCrop transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
"""
_v1_transform_cls = _transforms.CenterCrop
def __init__(self, size: Union[int, Sequence[int]]):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.center_crop(inpt, output_size=self.size)
class RandomResizedCrop(Transform):
"""[BETA] Crop a random portion of the input and resize it to a given size.
.. v2betastatus:: RandomResizedCrop transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
A crop of the original input is made: the crop has a random area (H * W)
and a random aspect ratio. This crop is finally resized to the given
size. This is popularly used to train the Inception networks.
Args:
size (int or sequence): expected output size of the crop, for each edge. If size is an
int instead of sequence like (h, w), a square output size ``(size, size)`` is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
.. note::
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
scale (tuple of float, optional): Specifies the lower and upper bounds for the random area of the crop,
before resizing. The scale is defined with respect to the area of the original image.
ratio (tuple of float, optional): lower and upper bounds for the random aspect ratio of the crop, before
resizing.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
antialias (bool, optional): Whether to apply antialiasing.
It only affects **tensors** with bilinear or bicubic modes and it is
ignored otherwise: on PIL images, antialiasing is always applied on
bilinear or bicubic modes; on other modes (for PIL images and
tensors), antialiasing makes no sense and this parameter is ignored.
Possible values are:
- ``True``: will apply antialiasing for bilinear or bicubic modes.
Other modes aren't affected. This is probably what you want to use.
- ``False``: will not apply antialiasing for tensors on any mode. PIL
images are still antialiased on bilinear or bicubic modes, because
PIL doesn't support disabling antialiasing.
- ``None``: equivalent to ``False`` for tensors and ``True`` for
PIL images. This value exists for legacy reasons and you probably
don't want to use it unless you really know what you are doing.
The current default is ``None`` **but will change to** ``True`` **in
v0.17** for the PIL and Tensor backends to be consistent.
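Example:
A minimal sketch (illustrative; the output spatial size always equals ``size``):
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> out = transforms.RandomResizedCrop(224, antialias=True)(image)
>>> out.shape
torch.Size([3, 224, 224])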
"""
_v1_transform_cls = _transforms.RandomResizedCrop
def __init__(
self,
size: Union[int, Sequence[int]],
scale: Tuple[float, float] = (0.08, 1.0),
ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if not isinstance(scale, Sequence):
raise TypeError("Scale should be a sequence")
scale = cast(Tuple[float, float], scale)
if not isinstance(ratio, Sequence):
raise TypeError("Ratio should be a sequence")
ratio = cast(Tuple[float, float], ratio)
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
self.scale = scale
self.ratio = ratio
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
area = height * width
log_ratio = self._log_ratio
for _ in range(10):
target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
i = torch.randint(0, height - h + 1, size=(1,)).item()
j = torch.randint(0, width - w + 1, size=(1,)).item()
break
else:
# Fallback to central crop
in_ratio = float(width) / float(height)
if in_ratio < min(self.ratio):
w = width
h = int(round(w / min(self.ratio)))
elif in_ratio > max(self.ratio):
h = height
w = int(round(h * max(self.ratio)))
else: # whole image
w = width
h = height
i = (height - h) // 2
j = (width - w) // 2
return dict(top=i, left=j, height=h, width=w)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resized_crop(
inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
)
ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]
class FiveCrop(Transform):
"""[BETA] Crop the image or video into four corners and the central crop.
.. v2betastatus:: FiveCrop transform
If the input is a :class:`torch.Tensor` or a :class:`~torchvision.datapoints.Image` or a
:class:`~torchvision.datapoints.Video` it can have arbitrary number of leading batch dimensions.
For example, the image can have ``[..., C, H, W]`` shape.
.. Note::
This transform returns a tuple of images and there may be a mismatch in the number of
inputs and targets your Dataset returns. See below for an example of how to deal with
this.
Args:
size (sequence or int): Desired output size of the crop. If size is an ``int``
instead of sequence like (h, w), a square crop of size (size, size) is made.
If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
Example:
>>> class BatchMultiCrop(transforms.Transform):
... def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], int]):
... images_or_videos, labels = sample
... batch_size = len(images_or_videos)
... image_or_video = images_or_videos[0]
... images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
... labels = torch.full((batch_size,), labels, device=images_or_videos.device)
... return images_or_videos, labels
...
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> label = 3
>>> transform = transforms.Compose([transforms.FiveCrop(224), BatchMultiCrop()])
>>> images, labels = transform(image, label)
>>> images.shape
torch.Size([5, 3, 224, 224])
>>> labels
tensor([3, 3, 3, 3, 3])
"""
_v1_transform_cls = _transforms.FiveCrop
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, size: Union[int, Sequence[int]]) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform(
self, inpt: ImageOrVideoTypeJIT, params: Dict[str, Any]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
return F.five_crop(inpt, self.size)
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
class TenCrop(Transform):
"""[BETA] Crop the image or video into four corners and the central crop plus the flipped version of
these (horizontal flipping is used by default).
.. v2betastatus:: TenCrop transform
If the input is a :class:`torch.Tensor` or a :class:`~torchvision.datapoints.Image` or a
:class:`~torchvision.datapoints.Video` it can have arbitrary number of leading batch dimensions.
For example, the image can have ``[..., C, H, W]`` shape.
See :class:`~torchvision.transforms.v2.FiveCrop` for an example.
.. Note::
This transform returns a tuple of images and there may be a mismatch in the number of
inputs and targets your Dataset returns. See below for an example of how to deal with
this.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
vertical_flip (bool, optional): Use vertical flipping instead of horizontal
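Example:
Illustrative sketch; like :class:`FiveCrop`, the result is a tuple of crops rather than a single tensor:
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> crops = transforms.TenCrop(224)(image)
>>> len(crops)
10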
"""
_v1_transform_cls = _transforms.TenCrop
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
def _transform(
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
) -> Tuple[
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
]:
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
class Pad(Transform):
"""[BETA] Pad the input on all sides with the given "pad" value.
.. v2betastatus:: Pad transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Args:
padding (int or sequence): Padding on each border. If a single int is provided this
is used to pad all borders. If sequence of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a sequence of length 4 is provided
this is the padding for the left, top, right and bottom borders respectively.
.. note::
In torchscript mode padding as single int is not supported, use a sequence of
length 1: ``[padding, ]``.
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
Default is "constant".
- constant: pads with a constant value, this value is specified with fill
- edge: pads with the last value at the edge of the image.
- reflect: pads with reflection of image without repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]
- symmetric: pads with reflection of image repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
will result in [2, 1, 1, 2, 3, 4, 4, 3]
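Example:
Illustrative sketch; ``padding=2`` adds two pixels on every side:
>>> image = datapoints.Image(torch.rand(3, 32, 32))
>>> out = transforms.Pad(padding=2, fill=0)(image)
>>> out.shape
torch.Size([3, 36, 36])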
"""
_v1_transform_cls = _transforms.Pad
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
params = super()._extract_params_for_v1_transform()
if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
return params
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
_check_padding_arg(padding)
_check_padding_mode_arg(padding_mode)
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
self.padding = padding
self.fill = fill
self._fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
class RandomZoomOut(_RandomApplyTransform):
"""[BETA] "Zoom out" transformation from
`"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.
.. v2betastatus:: RandomZoomOut transform
This transformation randomly pads images, videos, bounding boxes and masks creating a zoom out effect.
Output spatial size is randomly sampled from original size up to a maximum size configured
with ``side_range`` parameter:
.. code-block:: python
r = uniform_sample(side_range[0], side_range[1])
output_width = input_width * r
output_height = input_height * r
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Args:
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
side_range (sequence of floats, optional): tuple of two floats defines minimum and maximum factors to
scale the input size.
p (float, optional): probability of the zoom out being applied. Default value is 0.5
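Example:
Illustrative sketch; with ``p=1.0`` the zoom out is always applied, so the output is at least as
large as the input:
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> out = transforms.RandomZoomOut(p=1.0, side_range=(1.0, 2.0))(image)
>>> out.shape[-1] >= image.shape[-1]
True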
"""
def __init__(
self,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
super().__init__(p=p)
self.fill = fill
self._fill = _setup_fill_arg(fill)
_check_sequence_input(side_range, "side_range", req_sizes=(2,))
self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.")
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_spatial_size(flat_inputs)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)
r = torch.rand(2)
left = int((canvas_width - orig_w) * r[0])
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
bottom = canvas_height - (top + orig_h)
padding = [left, top, right, bottom]
return dict(padding=padding)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
return F.pad(inpt, **params, fill=fill)
class RandomRotation(Transform):
"""[BETA] Rotate the input by angle.
.. v2betastatus:: RandomRotation transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Args:
degrees (sequence or number): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees).
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
expand (bool, optional): Optional expansion flag.
If true, expands the output to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image.
Note that the expand flag assumes rotation around the center and no translation.
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
Default is the center of the image.
fill (number or tuple or dict, optional): Pixel fill value for the area outside the rotated image.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
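Example:
Illustrative sketch; with the default ``expand=False`` the spatial size is unchanged:
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> out = transforms.RandomRotation(degrees=15)(image)
>>> out.shape
torch.Size([3, 256, 256])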
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
"""
_v1_transform_cls = _transforms.RandomRotation
def __init__(
self,
degrees: Union[numbers.Number, Sequence],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
self.interpolation = _check_interpolation(interpolation)
self.expand = expand
self.fill = fill
self._fill = _setup_fill_arg(fill)
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
return F.rotate(
inpt,
**params,
interpolation=self.interpolation,
expand=self.expand,
center=self.center,
fill=fill,
)
class RandomAffine(Transform):
"""[BETA] Random affine transformation the input keeping center invariant.
.. v2betastatus:: RandomAffine transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Args:
degrees (sequence or number): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees). Set to 0 to deactivate rotations.
translate (tuple, optional): tuple of maximum absolute fraction for horizontal
and vertical translations. For example translate=(a, b), then horizontal shift
is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
scale (tuple, optional): scaling factor interval, e.g. (a, b), then scale is
randomly sampled from the range a <= scale <= b. Will keep original scale by default.
shear (sequence or number, optional): Range of degrees to select from.
If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear)
will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the
range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
Will not apply shear by default.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
fill (number or tuple or dict, optional): Pixel fill value for the area outside the transformed image.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
Default is the center of the image.
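Example:
Illustrative sketch; the affine transform keeps the input spatial size:
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> out = transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1))(image)
>>> out.shape
torch.Size([3, 256, 256])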
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
"""
_v1_transform_cls = _transforms.RandomAffine
def __init__(
self,
degrees: Union[numbers.Number, Sequence],
translate: Optional[Sequence[float]] = None,
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
if translate is not None:
_check_sequence_input(translate, "translate", req_sizes=(2,))
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
_check_sequence_input(scale, "scale", req_sizes=(2,))
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
else:
self.shear = shear
self.interpolation = _check_interpolation(interpolation)
self.fill = fill
self._fill = _setup_fill_arg(fill)
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
if self.translate is not None:
max_dx = float(self.translate[0] * width)
max_dy = float(self.translate[1] * height)
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
translate = (tx, ty)
else:
translate = (0, 0)
if self.scale is not None:
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
else:
scale = 1.0
shear_x = shear_y = 0.0
if self.shear is not None:
shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
if len(self.shear) == 4:
shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
shear = (shear_x, shear_y)
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
return F.affine(
inpt,
**params,
interpolation=self.interpolation,
fill=fill,
center=self.center,
)
class RandomCrop(Transform):
"""[BETA] Crop the input at a random location.
.. v2betastatus:: RandomCrop transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
padding (int or sequence, optional): Optional padding on each border
of the image. Default is None. If a single int is provided this
is used to pad all borders. If sequence of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a sequence of length 4 is provided
this is the padding for the left, top, right and bottom borders respectively.
.. note::
In torchscript mode padding as single int is not supported, use a sequence of
length 1: ``[padding, ]``.
pad_if_needed (boolean, optional): It will pad the image if smaller than the
desired size to avoid raising an exception. Since cropping is done
after padding, the padding seems to be done at a random offset.
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
Default is constant.
- constant: pads with a constant value, this value is specified with fill
- edge: pads with the last value at the edge of the image.
- reflect: pads with reflection of image without repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]
- symmetric: pads with reflection of image repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
will result in [2, 1, 1, 2, 3, 4, 4, 3]
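Example:
Illustrative sketch; ``pad_if_needed=True`` guards against inputs smaller than the crop:
>>> image = datapoints.Image(torch.rand(3, 200, 300))
>>> out = transforms.RandomCrop(224, pad_if_needed=True)(image)
>>> out.shape
torch.Size([3, 224, 224])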
"""
_v1_transform_cls = _transforms.RandomCrop
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
params = super()._extract_params_for_v1_transform()
if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
padding = self.padding
if padding is not None:
pad_left, pad_right, pad_top, pad_bottom = padding
padding = [pad_left, pad_top, pad_right, pad_bottom]
params["padding"] = padding
return params
def __init__(
self,
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if pad_if_needed or padding is not None:
if padding is not None:
_check_padding_arg(padding)
_check_padding_mode_arg(padding_mode)
self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type]
self.pad_if_needed = pad_if_needed
self.fill = fill
self._fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
padded_height, padded_width = query_spatial_size(flat_inputs)
if self.padding is not None:
pad_left, pad_right, pad_top, pad_bottom = self.padding
padded_height += pad_top + pad_bottom
padded_width += pad_left + pad_right
else:
pad_left = pad_right = pad_top = pad_bottom = 0
cropped_height, cropped_width = self.size
if self.pad_if_needed:
if padded_height < cropped_height:
diff = cropped_height - padded_height
pad_top += diff
pad_bottom += diff
padded_height += 2 * diff
if padded_width < cropped_width:
diff = cropped_width - padded_width
pad_left += diff
pad_right += diff
padded_width += 2 * diff
if padded_height < cropped_height or padded_width < cropped_width:
raise ValueError(
f"Required crop size {(cropped_height, cropped_width)} is larger than "
f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}."
)
# We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad`
padding = [pad_left, pad_top, pad_right, pad_bottom]
needs_pad = any(padding)
needs_vert_crop, top = (
(True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
if padded_height > cropped_height
else (False, 0)
)
needs_horz_crop, left = (
(True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
if padded_width > cropped_width
else (False, 0)
)
return dict(
needs_crop=needs_vert_crop or needs_horz_crop,
top=top,
left=left,
height=cropped_height,
width=cropped_width,
needs_pad=needs_pad,
padding=padding,
)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]:
fill = self._fill[type(inpt)]
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
if params["needs_crop"]:
inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
return inpt
class RandomPerspective(_RandomApplyTransform):
"""[BETA] Perform a random perspective transformation of the input with a given probability.
.. v2betastatus:: RandomPerspective transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Args:
distortion_scale (float, optional): argument to control the degree of distortion and ranges from 0 to 1.
Default is 0.5.
p (float, optional): probability of the input being transformed. Default is 0.5.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
fill (number or tuple or dict, optional): Pixel fill value for the area outside the transformed image.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
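Example:
Illustrative sketch; the output keeps the input spatial size:
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> out = transforms.RandomPerspective(distortion_scale=0.5, p=1.0)(image)
>>> out.shape
torch.Size([3, 256, 256])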
"""
_v1_transform_cls = _transforms.RandomPerspective
def __init__(
self,
distortion_scale: float = 0.5,
p: float = 0.5,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
) -> None:
super().__init__(p=p)
if not (0 <= distortion_scale <= 1):
raise ValueError("Argument distortion_scale value should be between 0 and 1")
self.distortion_scale = distortion_scale
self.interpolation = _check_interpolation(interpolation)
self.fill = fill
self._fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
distortion_scale = self.distortion_scale
half_height = height // 2
half_width = width // 2
bound_height = int(distortion_scale * half_height) + 1
bound_width = int(distortion_scale * half_width) + 1
topleft = [
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
]
topright = [
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
]
botright = [
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
]
botleft = [
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
endpoints = [topleft, topright, botright, botleft]
perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
return dict(coefficients=perspective_coeffs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
return F.perspective(
inpt,
None,
None,
fill=fill,
interpolation=self.interpolation,
**params,
)
class ElasticTransform(Transform):
"""[BETA] Transform the input with elastic transformations.
.. v2betastatus:: ElasticTransform transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Given alpha and sigma, it will generate displacement
vectors for all pixels based on random offsets. Alpha controls the strength
and sigma controls the smoothness of the displacements.
The displacements are added to an identity grid and the resulting grid is
used to transform the input.
.. note::
Implementation to transform bounding boxes is approximative (not exact).
We construct an approximation of the inverse grid as ``inverse_grid = identity - displacement``.
This is not an exact inverse of the grid used to transform images, i.e. ``grid = identity + displacement``.
Our assumption is that ``displacement * displacement`` is small and can be ignored.
Large displacements would lead to large errors in the approximation.
Applications:
Randomly transforms the morphology of objects in images and produces a
see-through-water-like effect.
Args:
alpha (float or sequence of floats, optional): Magnitude of displacements. Default is 50.0.
sigma (float or sequence of floats, optional): Smoothness of displacements. Default is 5.0.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
fill (number or tuple or dict, optional): Pixel fill value for the area outside the transformed image.
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
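Example:
Illustrative sketch; the displacements warp the content but the spatial size is unchanged:
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> out = transforms.ElasticTransform(alpha=50.0, sigma=5.0)(image)
>>> out.shape
torch.Size([3, 256, 256])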
"""
_v1_transform_cls = _transforms.ElasticTransform
def __init__(
self,
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
self.interpolation = _check_interpolation(interpolation)
self.fill = fill
self._fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
size = list(query_spatial_size(flat_inputs))
dx = torch.rand([1, 1] + size) * 2 - 1
if self.sigma[0] > 0.0:
kx = int(8 * self.sigma[0] + 1)
# if kernel size is even we have to make it odd
if kx % 2 == 0:
kx += 1
dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma))
dx = dx * self.alpha[0] / size[0]
dy = torch.rand([1, 1] + size) * 2 - 1
if self.sigma[1] > 0.0:
ky = int(8 * self.sigma[1] + 1)
# if kernel size is even we have to make it odd
if ky % 2 == 0:
ky += 1
dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma))
dy = dy * self.alpha[1] / size[1]
displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2
return dict(displacement=displacement)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
return F.elastic(
inpt,
**params,
fill=fill,
interpolation=self.interpolation,
)
class RandomIoUCrop(Transform):
"""[BETA] Random IoU crop transformation from
`"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.
.. v2betastatus:: RandomIoUCrop transform
This transformation requires an image or video data and ``datapoints.BoundingBox`` in the input.
.. warning::
In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop`
must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBox`, either immediately
after or later in the transforms pipeline.
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Args:
min_scale (float, optional): Minimum factors to scale the input size.
max_scale (float, optional): Maximum factors to scale the input size.
min_aspect_ratio (float, optional): Minimum aspect ratio for the cropped image or video.
max_aspect_ratio (float, optional): Maximum aspect ratio for the cropped image or video.
sampler_options (list of float, optional): List of minimal IoU (Jaccard) overlap between all the boxes and
a cropped image or video. Default, ``None`` which corresponds to ``[0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]``
trials (int, optional): Number of trials to find a crop for a given value of minimal IoU (Jaccard) overlap.
Default, 40.
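Example:
Illustrative sketch; assumes ``datapoints.BoundingBox`` takes ``format`` and ``spatial_size`` keyword
arguments (as in torchvision v0.15). In a real pipeline the crop should be followed by
:class:`~torchvision.transforms.v2.SanitizeBoundingBox`, as noted above:
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> boxes = datapoints.BoundingBox(
...     [[10, 10, 100, 100], [40, 60, 200, 220]],
...     format=datapoints.BoundingBoxFormat.XYXY,
...     spatial_size=(256, 256),
... )
>>> out_image, out_boxes = transforms.RandomIoUCrop()(image, boxes)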
"""
def __init__(
self,
min_scale: float = 0.3,
max_scale: float = 1.0,
min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0,
sampler_options: Optional[List[float]] = None,
trials: int = 40,
):
super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale
self.max_scale = max_scale
self.min_aspect_ratio = min_aspect_ratio
self.max_aspect_ratio = max_aspect_ratio
if sampler_options is None:
sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
self.options = sampler_options
self.trials = trials
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not (
has_all(flat_inputs, datapoints.BoundingBox)
and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor)
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain tensor or PIL images "
"and bounding boxes. Sample can also contain masks."
)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_spatial_size(flat_inputs)
bboxes = query_bounding_box(flat_inputs)
while True:
# sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
min_jaccard_overlap = self.options[idx]
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
return dict()
for _ in range(self.trials):
# check the aspect ratio limitations
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
new_w = int(orig_w * r[0])
new_h = int(orig_h * r[1])
aspect_ratio = new_w / new_h
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
continue
# check for 0 area crops
r = torch.rand(2)
left = int((orig_w - new_w) * r[0])
top = int((orig_h - new_h) * r[1])
right = left + new_w
bottom = top + new_h
if left == right or top == bottom:
continue
# check for any valid boxes with centers within the crop area
xyxy_bboxes = F.convert_format_bounding_box(
bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
)
cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
if not is_within_crop_area.any():
continue
# check at least 1 box with jaccard limitations
xyxy_bboxes = xyxy_bboxes[is_within_crop_area]
ious = box_iou(
xyxy_bboxes,
torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device),
)
if ious.max() < min_jaccard_overlap:
continue
return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if len(params) < 1:
return inpt
output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
if isinstance(output, datapoints.BoundingBox):
# We "mark" the invalid boxes as degenreate, and they can be
# removed by a later call to SanitizeBoundingBox()
output[~params["is_within_crop_area"]] = 0
return output
class ScaleJitter(Transform):
"""[BETA] Perform Large Scale Jitter on the input according to
`"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.
.. v2betastatus:: ScaleJitter transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Args:
target_size (tuple of int): Target size. This parameter defines base scale for jittering,
e.g. ``min(target_size[0] / width, target_size[1] / height)``.
scale_range (tuple of float, optional): Minimum and maximum of the scale range. Default, ``(0.1, 2.0)``.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
antialias (bool, optional): Whether to apply antialiasing.
It only affects **tensors** with bilinear or bicubic modes and it is
ignored otherwise: on PIL images, antialiasing is always applied on
bilinear or bicubic modes; on other modes (for PIL images and
tensors), antialiasing makes no sense and this parameter is ignored.
Possible values are:
- ``True``: will apply antialiasing for bilinear or bicubic modes.
Other modes aren't affected. This is probably what you want to use.
- ``False``: will not apply antialiasing for tensors on any mode. PIL
images are still antialiased on bilinear or bicubic modes, because
PIL doesn't support disabling antialiasing.
- ``None``: equivalent to ``False`` for tensors and ``True`` for
PIL images. This value exists for legacy reasons and you probably
don't want to use it unless you really know what you are doing.
The current default is ``None`` **but will change to** ``True`` **in
v0.17** for the PIL and Tensor backends to be consistent.
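Example:
Illustrative sketch; the output size is the input size scaled by a random factor relative to ``target_size``:
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> out = transforms.ScaleJitter(target_size=(512, 512), scale_range=(0.5, 1.5), antialias=True)(image)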
"""
def __init__(
self,
target_size: Tuple[int, int],
scale_range: Tuple[float, float] = (0.1, 2.0),
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
):
super().__init__()
self.target_size = target_size
self.scale_range = scale_range
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_height, orig_width = query_spatial_size(flat_inputs)
scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
new_width = int(orig_width * r)
new_height = int(orig_height * r)
return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
class RandomShortestSize(Transform):
"""[BETA] Randomly resize the input.
.. v2betastatus:: RandomShortestSize transform
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Args:
min_size (int or sequence of int): Minimum spatial size. Single integer value or a sequence of integer values.
max_size (int, optional): Maximum spatial size. Default, None.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
antialias (bool, optional): Whether to apply antialiasing.
It only affects **tensors** with bilinear or bicubic modes and it is
ignored otherwise: on PIL images, antialiasing is always applied on
bilinear or bicubic modes; on other modes (for PIL images and
tensors), antialiasing makes no sense and this parameter is ignored.
Possible values are:
- ``True``: will apply antialiasing for bilinear or bicubic modes.
Other modes aren't affected. This is probably what you want to use.
- ``False``: will not apply antialiasing for tensors on any mode. PIL
images are still antialiased on bilinear or bicubic modes, because
PIL doesn't support disabling antialiasing.
- ``None``: equivalent to ``False`` for tensors and ``True`` for
PIL images. This value exists for legacy reasons and you probably
don't want to use it unless you really know what you are doing.
The current default is ``None`` **but will change to** ``True`` **in
v0.17** for the PIL and Tensor backends to be consistent.
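Example:
Illustrative sketch; the shorter edge is resized to a value drawn from ``min_size``, capped by ``max_size``:
>>> image = datapoints.Image(torch.rand(3, 300, 400))
>>> out = transforms.RandomShortestSize(min_size=480, max_size=800, antialias=True)(image)
>>> out.shape
torch.Size([3, 480, 640])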
"""
def __init__(
self,
min_size: Union[List[int], Tuple[int], int],
max_size: Optional[int] = None,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
):
super().__init__()
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_height, orig_width = query_spatial_size(flat_inputs)
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
r = min_size / min(orig_height, orig_width)
if self.max_size is not None:
r = min(r, self.max_size / max(orig_height, orig_width))
new_width = int(orig_width * r)
new_height = int(orig_height * r)
return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
class RandomResize(Transform):
"""[BETA] Randomly resize the input.
.. v2betastatus:: RandomResize transform
This transformation can be used together with ``RandomCrop`` as data augmentations to train
models on image segmentation tasks.
Output spatial size is randomly sampled from the interval ``[min_size, max_size]``:
.. code-block:: python
size = uniform_sample(min_size, max_size)
output_width = size
output_height = size
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
:class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBox` etc.)
it can have arbitrary number of leading batch dimensions. For example,
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
Args:
min_size (int): Minimum output size for random sampling
max_size (int): Maximum output size for random sampling
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
antialias (bool, optional): Whether to apply antialiasing.
It only affects **tensors** with bilinear or bicubic modes and it is
ignored otherwise: on PIL images, antialiasing is always applied on
bilinear or bicubic modes; on other modes (for PIL images and
tensors), antialiasing makes no sense and this parameter is ignored.
Possible values are:
- ``True``: will apply antialiasing for bilinear or bicubic modes.
Other modes aren't affected. This is probably what you want to use.
- ``False``: will not apply antialiasing for tensors on any mode. PIL
images are still antialiased on bilinear or bicubic modes, because
PIL doesn't support disabling antialiasing.
- ``None``: equivalent to ``False`` for tensors and ``True`` for
PIL images. This value exists for legacy reasons and you probably
don't want to use it unless you really know what you are doing.
The current default is ``None`` **but will change to** ``True`` **in
v0.17** for the PIL and Tensor backends to be consistent.
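Example:
Illustrative sketch; with a square input the sampled output size lies in ``[min_size, max_size)``:
>>> image = datapoints.Image(torch.rand(3, 256, 256))
>>> out = transforms.RandomResize(min_size=224, max_size=256, antialias=True)(image)
>>> 224 <= out.shape[-1] < 256
True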
"""
def __init__(
self,
min_size: int,
max_size: int,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> None:
super().__init__()
self.min_size = min_size
self.max_size = max_size
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
size = int(torch.randint(self.min_size, self.max_size, ()))
return dict(size=[size])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias)
from typing import Any, Dict, Union
import torch
from util import datapoints
import transforms as _transforms
from transforms.v2 import functional as F, Transform
from .utils import is_simple_tensor
class ConvertBoundingBoxFormat(Transform):
"""[BETA] Convert bounding box coordinates to the given ``format``, eg from "CXCYWH" to "XYXY".
.. v2betastatus:: ConvertBoundingBoxFormat transform
Args:
format (str or datapoints.BoundingBoxFormat): output bounding box format.
Possible values are defined by :class:`~torchvision.datapoints.BoundingBoxFormat` and
string values match the enums, e.g. "XYXY" or "XYWH" etc.
"""
_transformed_types = (datapoints.BoundingBox,)
def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None:
super().__init__()
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]
self.format = format
def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value]
class ConvertDtype(Transform):
"""[BETA] Convert input image or video to the given ``dtype`` and scale the values accordingly.
.. v2betastatus:: ConvertDtype transform
This function does not support PIL Image.
Args:
dtype (torch.dtype): Desired data type of the output
.. note::
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
If converted back and forth, this mismatch has no effect.
Raises:
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
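Example:
Illustrative sketch; a ``uint8`` image in ``[0, 255]`` becomes ``float32`` in ``[0.0, 1.0]``:
>>> image = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
>>> out = transforms.ConvertDtype(torch.float32)(image)
>>> out.dtype
torch.float32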
"""
_v1_transform_cls = _transforms.ConvertImageDtype
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype
def _transform(
self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
) -> Union[datapoints._TensorImageType, datapoints._TensorVideoType]:
return F.convert_dtype(inpt, self.dtype)
# We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ConvertImageDtype = ConvertDtype
class ClampBoundingBox(Transform):
"""[BETA] Clamp bounding boxes to their corresponding image dimensions.
The clamping is done according to the bounding boxes' ``spatial_size`` meta-data.
.. v2betastatus:: ClampBoundingBox transform
"""
_transformed_types = (datapoints.BoundingBox,)
def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
return F.clamp_bounding_box(inpt) # type: ignore[return-value]
import collections
import warnings
from contextlib import suppress
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Sequence, Type, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from util import datapoints
import transforms as _transforms
from transforms.v2 import functional as F, Transform
from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size
from .utils import has_any, is_simple_tensor, query_bounding_box
# TODO: do we want/need to expose this?
class Identity(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt
class Lambda(Transform):
"""[BETA] Apply a user-defined function as a transform.
.. v2betastatus:: Lambda transform
This transform does not support torchscript.
Args:
lambd (function): Lambda/function to be used for transform.
"""
def __init__(self, lambd: Callable[[Any], Any], *types: Type):
super().__init__()
self.lambd = lambd
self.types = types or (object,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, self.types):
return self.lambd(inpt)
else:
return inpt
def extra_repr(self) -> str:
extras = []
name = getattr(self.lambd, "__name__", None)
if name:
extras.append(name)
extras.append(f"types={[type.__name__ for type in self.types]}")
return ", ".join(extras)
class LinearTransformation(Transform):
"""[BETA] Transform a tensor image or video with a square transformation matrix and a mean_vector computed offline.
.. v2betastatus:: LinearTransformation transform
This transform does not support PIL Image.
Given transformation_matrix and mean_vector, this transform flattens the torch.*Tensor,
subtracts mean_vector from it, computes the dot product with the transformation matrix,
and reshapes the tensor back to its original shape.
Applications:
whitening transformation: Suppose X is a column vector of zero-centered data.
Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
perform SVD on this matrix and pass it as transformation_matrix.
Args:
transformation_matrix (Tensor): tensor [D x D], D = C x H x W
mean_vector (Tensor): tensor [D], D = C x H x W
"""
_v1_transform_cls = _transforms.LinearTransformation
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
super().__init__()
if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError(
"transformation_matrix should be square. Got "
f"{tuple(transformation_matrix.size())} rectangular matrix."
)
if mean_vector.size(0) != transformation_matrix.size(0):
raise ValueError(
f"mean_vector should have the same length {mean_vector.size(0)}"
f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
)
if transformation_matrix.device != mean_vector.device:
raise ValueError(
f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
)
if transformation_matrix.dtype != mean_vector.dtype:
raise ValueError(
f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
)
self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector
def _check_inputs(self, sample: Any) -> Any:
if has_any(sample, PIL.Image.Image):
raise TypeError("LinearTransformation does not work on PIL Images")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
shape = inpt.shape
n = shape[-3] * shape[-2] * shape[-1]
if n != self.transformation_matrix.shape[0]:
raise ValueError(
"Input tensor and transformation matrix have incompatible shape."
+ f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
+ f"{self.transformation_matrix.shape[0]}"
)
if inpt.device.type != self.mean_vector.device.type:
raise ValueError(
"Input tensor should be on the same device as transformation matrix and mean vector. "
f"Got {inpt.device} vs {self.mean_vector.device}"
)
flat_inpt = inpt.reshape(-1, n) - self.mean_vector
transformation_matrix = self.transformation_matrix.to(flat_inpt.dtype)
output = torch.mm(flat_inpt, transformation_matrix)
output = output.reshape(shape)
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
return output
class Normalize(Transform):
"""[BETA] Normalize a tensor image or video with mean and standard deviation.
.. v2betastatus:: Normalize transform
This transform does not support PIL Image.
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
channels, this transform will normalize each channel of the input
``torch.*Tensor`` i.e.,
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
.. note::
This transform acts out of place, i.e., it does not mutate the input tensor.
Args:
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace (bool, optional): Bool to make this operation in-place.
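Example:
Illustrative sketch using the common ImageNet statistics; the input must already be a float tensor:
>>> image = datapoints.Image(torch.rand(3, 32, 32))
>>> out = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)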
"""
_v1_transform_cls = _transforms.Normalize
_transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)
def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
super().__init__()
self.mean = list(mean)
self.std = list(std)
self.inplace = inplace
def _check_inputs(self, sample: Any) -> Any:
if has_any(sample, PIL.Image.Image):
raise TypeError(f"{type(self).__name__}() does not support PIL images.")
def _transform(
self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
) -> Any:
return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
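# A minimal usage sketch (illustrative values; the mean/std below are the commonly used ImageNet statistics):
#   normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#   out = normalize(torch.rand(3, 224, 224))  # per channel: (input - mean) / std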
class GaussianBlur(Transform):
"""[BETA] Blurs image with randomly chosen Gaussian blur.
.. v2betastatus:: GaussianBlur transform
If the input is a Tensor, it is expected
to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
kernel_size (int or sequence): Size of the Gaussian kernel.
sigma (float or tuple of float (min, max)): Standard deviation to be used for
creating the kernel that performs the blurring. If float, sigma is fixed. If it is a tuple
of floats (min, max), sigma is chosen uniformly at random to lie in the
given range.
"""
_v1_transform_cls = _transforms.GaussianBlur
def __init__(
self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0)
) -> None:
super().__init__()
self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
for ks in self.kernel_size:
if ks <= 0 or ks % 2 == 0:
raise ValueError("Kernel size value should be an odd and positive number.")
if isinstance(sigma, (int, float)):
if sigma <= 0:
raise ValueError("If sigma is a single number, it must be positive.")
sigma = float(sigma)
elif isinstance(sigma, Sequence) and len(sigma) == 2:
if not 0.0 < sigma[0] <= sigma[1]:
raise ValueError("sigma values should be positive and of the form (min, max).")
else:
raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.")
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()
return dict(sigma=[sigma, sigma])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.gaussian_blur(inpt, self.kernel_size, **params)
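# A minimal usage sketch (illustrative values): a 5x5 kernel whose sigma is re-drawn uniformly from
# [0.1, 2.0] on every call.
#   blur = GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
#   out = blur(torch.rand(3, 64, 64))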
class ToDtype(Transform):
"""[BETA] Converts the input to a specific dtype - this does not scale values.
.. v2betastatus:: ToDtype transform
Args:
dtype (``torch.dtype`` or dict of ``Datapoint`` -> ``torch.dtype``): The dtype to convert to.
A dict can be passed to specify per-datapoint conversions, e.g.
``dtype={datapoints.Image: torch.float32, datapoints.Video:
torch.float64}``.
"""
_transformed_types = (torch.Tensor,)
def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
super().__init__()
if not isinstance(dtype, dict):
dtype = _get_defaultdict(dtype)
if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]):
warnings.warn(
"Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
)
self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
dtype = self.dtype[type(inpt)]
if dtype is None:
return inpt
return inpt.to(dtype=dtype)
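# A minimal usage sketch (hypothetical): convert Image datapoints to float32 while leaving Videos untouched.
# Note that values are converted but *not* rescaled.
#   to_dtype = ToDtype({datapoints.Image: torch.float32, datapoints.Video: None})
#   out = to_dtype(datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)))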
class SanitizeBoundingBox(Transform):
"""[BETA] Remove degenerate/invalid bounding boxes and their corresponding labels and masks.
.. v2betastatus:: SanitizeBoundingBox transform
This transform removes bounding boxes and their associated labels/masks that:
- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- have any coordinate outside of their corresponding image. You may want to
call :class:`~torchvision.transforms.v2.ClampBoundingBox` first to avoid undesired removals.
It is recommended to call it at the end of a pipeline, before passing the
input to the models. It is critical to call this transform if
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
If you want to be extra careful, you may call it after all transforms that
may modify bounding boxes but once at the end should be enough in most
cases.
Args:
min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input.
It can be a str in which case the input is expected to be a dict, and ``labels_getter`` then specifies
the key whose value corresponds to the labels. It can also be a callable that takes the same input
as the transform, and returns the labels.
By default, this will try to find a "labels" key in the input, if
the input is a dict or it is a tuple whose second element is a dict.
This heuristic should work well with a lot of datasets, including the built-in torchvision datasets.
"""
def __init__(
self,
min_size: float = 1.0,
labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
) -> None:
super().__init__()
if min_size < 1:
raise ValueError(f"min_size must be >= 1, got {min_size}.")
self.min_size = min_size
self.labels_getter = labels_getter
self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]]
if labels_getter == "default":
self._labels_getter = self._find_labels_default_heuristic
elif callable(labels_getter):
self._labels_getter = labels_getter
elif isinstance(labels_getter, str):
self._labels_getter = lambda inputs: SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)[
labels_getter # type: ignore[index]
]
elif labels_getter is None:
self._labels_getter = None
else:
raise ValueError(
"labels_getter should either be a str, callable, or 'default'. "
f"Got {labels_getter} of type {type(labels_getter)}."
)
@staticmethod
def _get_dict_or_second_tuple_entry(inputs: Any) -> Mapping[str, Any]:
# datasets outputs may be plain dicts like {"img": ..., "labels": ..., "bbox": ...}
# or tuples like (img, {"labels":..., "bbox": ...})
# This hacky helper accounts for both structures.
if isinstance(inputs, tuple):
inputs = inputs[1]
if not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"If labels_getter is a str or 'default', "
f"then the input to forward() must be a dict or a tuple whose second element is a dict."
f" Got {type(inputs)} instead."
)
return inputs
@staticmethod
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
# Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found
inputs = SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
if candidate_key is None:
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
if candidate_key is None:
raise ValueError(
"Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
"If there are no samples and it is by design, pass labels_getter=None."
)
return inputs[candidate_key]
def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0]
if self._labels_getter is None:
labels = None
else:
labels = self._labels_getter(inputs)
if labels is not None and not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.")
flat_inputs, spec = tree_flatten(inputs)
# TODO: this enforces one single BoundingBox entry.
# Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
# should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
boxes = query_bounding_box(flat_inputs)
if boxes.ndim != 2:
raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
if labels is not None and boxes.shape[0] != labels.shape[0]:
raise ValueError(
f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
)
boxes = cast(
datapoints.BoundingBox,
F.convert_format_bounding_box(
boxes,
new_format=datapoints.BoundingBoxFormat.XYXY,
),
)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
image_h, image_w = boxes.spatial_size
valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
params = dict(valid=valid, labels=labels)
flat_outputs = [
# Even though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxes and the labels
self._transform(inpt, params)
for inpt in flat_inputs
]
return tree_unflatten(flat_outputs, spec)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
is_label = inpt is not None and inpt is params["labels"]
is_bounding_box_or_mask = isinstance(inpt, (datapoints.BoundingBox, datapoints.Mask))
if not (is_label or is_bounding_box_or_mask):
return inpt
output = inpt[params["valid"]]
if is_label:
return output
return type(inpt).wrap_like(inpt, output)
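# A minimal usage sketch (hypothetical sample layout): degenerate boxes and their labels are dropped together.
#   boxes = datapoints.BoundingBox(
#       [[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 5.0, 5.0]],  # the second box has zero width and height
#       format=datapoints.BoundingBoxFormat.XYXY,
#       spatial_size=(32, 32),
#   )
#   sample = {"image": datapoints.Image(torch.rand(3, 32, 32)), "boxes": boxes, "labels": torch.tensor([1, 2])}
#   out = SanitizeBoundingBox()(sample)  # keeps only the first box and its label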
from typing import Any, Dict
from util import datapoints
from transforms.v2 import functional as F, Transform
from transforms.v2.utils import is_simple_tensor
class UniformTemporalSubsample(Transform):
"""[BETA] Uniformly subsample ``num_samples`` indices from the temporal dimension of the video.
.. v2betastatus:: UniformTemporalSubsample transform
Videos are expected to be of shape ``[..., T, C, H, W]`` where ``T`` denotes the temporal dimension.
When ``num_samples`` is larger than the size of the temporal dimension of the video, it
will sample frames based on nearest neighbor interpolation.
Args:
num_samples (int): The number of equispaced samples to be selected
"""
_transformed_types = (is_simple_tensor, datapoints.Video)
def __init__(self, num_samples: int):
super().__init__()
self.num_samples = num_samples
def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType:
return F.uniform_temporal_subsample(inpt, self.num_samples)
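# A minimal usage sketch (hypothetical shapes): pick 4 equispaced frames from a 16-frame video.
#   video = datapoints.Video(torch.rand(16, 3, 8, 8))  # [T, C, H, W]
#   out = UniformTemporalSubsample(num_samples=4)(video)  # [4, 3, 8, 8]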
from __future__ import annotations
import enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from util import datapoints
from transforms.v2.utils import check_type, has_any, is_simple_tensor
class Transform(nn.Module):
# Class attribute defining transformed types. Other types are passed-through without any transformation
# We support both Types and callables that are able to do further checks on the type of the input.
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
def __init__(self) -> None:
super().__init__()
def _check_inputs(self, flat_inputs: List[Any]) -> None:
pass
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict()
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
raise NotImplementedError
def forward(self, *inputs: Any) -> Any:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
self._check_inputs(flat_inputs)
needs_transform_list = self._needs_transform_list(flat_inputs)
params = self._get_params(
[inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
)
flat_outputs = [
self._transform(inpt, params) if needs_transform else inpt
for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
]
return tree_unflatten(flat_outputs, spec)
def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
# Below is a heuristic on how to deal with simple tensor inputs:
# 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
# (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
# 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is
# transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
# of `tree_flatten`, which recurses depth-first through the input.
#
# This heuristic stems from two requirements:
# 1. We need to keep BC for single input simple tensors and treat them as images.
# 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface`
# return supplemental numerical data as tensors that cannot be transformed as images.
#
# The heuristic should work well for most people in practice. The only case where it doesn't is if someone
# tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
# However, this case wasn't supported by transforms v1 either, so there is no BC concern.
needs_transform_list = []
transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
for inpt in flat_inputs:
needs_transform = True
if not check_type(inpt, self._transformed_types):
needs_transform = False
elif is_simple_tensor(inpt):
if transform_simple_tensor:
transform_simple_tensor = False
else:
needs_transform = False
needs_transform_list.append(needs_transform)
return needs_transform_list
def extra_repr(self) -> str:
extra = []
for name, value in self.__dict__.items():
if name.startswith("_") or name == "training":
continue
if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)):
continue
extra.append(f"{name}={value}")
return ", ".join(extra)
# This attribute should be set on all transforms that have a v1 equivalent. Doing so enables two things:
# 1. In case the v1 transform has a static `get_params` method, it will also be available under the same name on
# the v2 transform. See `__init_subclass__` for details.
# 2. The v2 transform will be JIT scriptable. See `_extract_params_for_v1_transform` and `__prepare_scriptable__`
# for details.
_v1_transform_cls: Optional[Type[nn.Module]] = None
def __init_subclass__(cls) -> None:
# Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
# This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
cls.get_params = staticmethod(cls._v1_transform_cls.get_params) # type: ignore[attr-defined]
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
# v2 transform instance. It extracts all available public attributes that are specific to that transform and
# not `nn.Module` in general.
# Override this method on the v2 transform class if the above is not sufficient. For example, this might happen
# if the v2 transform introduced new parameters that are not supported by the v1 transform.
common_attrs = nn.Module().__dict__.keys()
return {
attr: value
for attr, value in self.__dict__.items()
if not attr.startswith("_") and attr not in common_attrs
}
def __prepare_scriptable__(self) -> nn.Module:
# This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return
# value is used for scripting over the original object that should have been scripted. Since the v1 transforms
# are JIT scriptable, and we made sure that for single image inputs v1 and v2 are equivalent, we just return the
# equivalent v1 transform here. This of course only makes transforms v2 JIT scriptable as long as transforms v1
# is around.
if self._v1_transform_cls is None:
raise RuntimeError(
f"Transform {type(self).__name__} cannot be JIT scripted. "
"torchscript is only supported for backward compatibility with transforms "
"which are already in torchvision.transforms. "
"For torchscript support (on tensors only), you can use the functional API instead."
)
return self._v1_transform_cls(**self._extract_params_for_v1_transform())
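# A minimal subclassing sketch (hypothetical transform, not part of this module): a custom transform only needs
# `_transform` and, if it draws random parameters per call, `_get_params`.
#   class _AddUniformNoise(Transform):
#       _transformed_types = (torch.Tensor,)
#       def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
#           return dict(noise_level=float(torch.rand(())))
#       def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
#           return inpt + params["noise_level"] * torch.rand_like(inpt)
#   out = _AddUniformNoise()(torch.rand(3, 8, 8))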
class _RandomApplyTransform(Transform):
def __init__(self, p: float = 0.5) -> None:
if not (0.0 <= p <= 1.0):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
super().__init__()
self.p = p
def forward(self, *inputs: Any) -> Any:
# We need to almost duplicate `Transform.forward()` here since we always want to check the inputs, but return
# early afterwards in case the random check triggers. The same result could be achieved by calling
# `super().forward()` after the random check, but that would call `self._check_inputs` twice.
inputs = inputs if len(inputs) > 1 else inputs[0]
flat_inputs, spec = tree_flatten(inputs)
self._check_inputs(flat_inputs)
if torch.rand(1) >= self.p:
return inputs
needs_transform_list = self._needs_transform_list(flat_inputs)
params = self._get_params(
[inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
)
flat_outputs = [
self._transform(inpt, params) if needs_transform else inpt
for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
]
return tree_unflatten(flat_outputs, spec)
from typing import Any, Dict, Optional, Union
import numpy as np
import PIL.Image
import torch
from util import datapoints
from transforms.v2 import functional as F, Transform
from transforms.v2.utils import is_simple_tensor
class PILToTensor(Transform):
"""[BETA] Convert a PIL Image to a tensor of the same type - this does not scale values.
.. v2betastatus:: PILToTensor transform
This transform does not support torchscript.
Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
"""
_transformed_types = (PIL.Image.Image,)
def _transform(self, inpt: PIL.Image.Image, params: Dict[str, Any]) -> torch.Tensor:
return F.pil_to_tensor(inpt)
class ToImageTensor(Transform):
"""[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image`
; this does not scale values.
.. v2betastatus:: ToImageTensor transform
This transform does not support torchscript.
"""
_transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> datapoints.Image:
return F.to_image_tensor(inpt)
class ToImagePIL(Transform):
"""[BETA] Convert a tensor or an ndarray to PIL Image - this does not scale values.
.. v2betastatus:: ToImagePIL transform
This transform does not support torchscript.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to a PIL Image while preserving the value range.
Args:
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
- If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
- If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
- If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
- If the input has 1 channel, the ``mode`` is determined by the data type (i.e. ``int``, ``float``,
``short``).
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
"""
_transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None:
super().__init__()
self.mode = mode
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> PIL.Image.Image:
return F.to_image_pil(inpt, mode=self.mode)
# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ToPILImage = ToImagePIL
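# A minimal usage sketch (hypothetical sizes): round-trip between PIL and tensor representations.
#   pil_img = PIL.Image.new("RGB", (8, 8))
#   img = ToImageTensor()(pil_img)   # datapoints.Image of shape [3, 8, 8], dtype uint8
#   back = ToImagePIL()(img)         # a PIL.Image.Image again; the value range is preserved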
import functools
import numbers
from collections import defaultdict
from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union
from util import datapoints
from util.datapoints import _FillType, _FillTypeJIT
from transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
if not isinstance(arg, (float, Sequence)):
raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
if isinstance(arg, Sequence) and len(arg) != req_size:
raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}")
if isinstance(arg, Sequence):
for element in arg:
if not isinstance(element, float):
raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")
if isinstance(arg, float):
arg = [float(arg), float(arg)]
if isinstance(arg, (list, tuple)) and len(arg) == 1:
arg = [arg[0], arg[0]]
return arg
def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None:
if isinstance(fill, dict):
for key, value in fill.items():
# The dict keys are datapoint types; recursively check the fill value given for each of them
_check_fill_arg(value)
if isinstance(fill, defaultdict) and callable(fill.default_factory):
default_value = fill.default_factory()
_check_fill_arg(default_value)
else:
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
T = TypeVar("T")
def _default_arg(value: T) -> T:
return value
def _get_defaultdict(default: T) -> Dict[Any, T]:
# This weird-looking construct only exists because `lambda`s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: default)`
return defaultdict(functools.partial(_default_arg, default))
def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0
# if fill is None:
# fill = 0
if fill is None:
return fill
if not isinstance(fill, (int, float)):
fill = [float(v) for v in list(fill)]
return fill # type: ignore[return-value]
def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]:
_check_fill_arg(fill)
if isinstance(fill, dict):
for k, v in fill.items():
fill[k] = _convert_fill_arg(v)
if isinstance(fill, defaultdict) and callable(fill.default_factory):
default_value = fill.default_factory()
sanitized_default = _convert_fill_arg(default_value)
fill.default_factory = functools.partial(_default_arg, sanitized_default)
return fill # type: ignore[return-value]
return _get_defaultdict(_convert_fill_arg(fill))
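# A minimal behavior sketch (hypothetical values): a scalar fill is expanded into a per-type default, while a
# dict can specify the fill per datapoint type.
#   _setup_fill_arg(0)[datapoints.Image]                                                  # -> 0 for every type
#   _setup_fill_arg({datapoints.Image: (1, 2, 3), datapoints.Mask: 0})[datapoints.Image]  # -> [1.0, 2.0, 3.0]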
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
from transforms import InterpolationMode # usort: skip
from ._utils import is_simple_tensor # usort: skip
from ._meta import (
clamp_bounding_box,
convert_format_bounding_box,
convert_dtype_image_tensor,
convert_dtype,
convert_dtype_video,
convert_image_dtype,
get_dimensions_image_tensor,
get_dimensions_image_pil,
get_dimensions,
get_num_frames_video,
get_num_frames,
get_image_num_channels,
get_num_channels_image_tensor,
get_num_channels_image_pil,
get_num_channels_video,
get_num_channels,
get_spatial_size_bounding_box,
get_spatial_size_image_tensor,
get_spatial_size_image_pil,
get_spatial_size_mask,
get_spatial_size_video,
get_spatial_size,
) # usort: skip
from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video
from ._color import (
adjust_brightness,
adjust_brightness_image_pil,
adjust_brightness_image_tensor,
adjust_brightness_video,
adjust_contrast,
adjust_contrast_image_pil,
adjust_contrast_image_tensor,
adjust_contrast_video,
adjust_gamma,
adjust_gamma_image_pil,
adjust_gamma_image_tensor,
adjust_gamma_video,
adjust_hue,
adjust_hue_image_pil,
adjust_hue_image_tensor,
adjust_hue_video,
adjust_saturation,
adjust_saturation_image_pil,
adjust_saturation_image_tensor,
adjust_saturation_video,
adjust_sharpness,
adjust_sharpness_image_pil,
adjust_sharpness_image_tensor,
adjust_sharpness_video,
autocontrast,
autocontrast_image_pil,
autocontrast_image_tensor,
autocontrast_video,
equalize,
equalize_image_pil,
equalize_image_tensor,
equalize_video,
invert,
invert_image_pil,
invert_image_tensor,
invert_video,
posterize,
posterize_image_pil,
posterize_image_tensor,
posterize_video,
rgb_to_grayscale,
rgb_to_grayscale_image_pil,
rgb_to_grayscale_image_tensor,
solarize,
solarize_image_pil,
solarize_image_tensor,
solarize_video,
)
from ._geometry import (
affine,
affine_bounding_box,
affine_image_pil,
affine_image_tensor,
affine_mask,
affine_video,
center_crop,
center_crop_bounding_box,
center_crop_image_pil,
center_crop_image_tensor,
center_crop_mask,
center_crop_video,
crop,
crop_bounding_box,
crop_image_pil,
crop_image_tensor,
crop_mask,
crop_video,
elastic,
elastic_bounding_box,
elastic_image_pil,
elastic_image_tensor,
elastic_mask,
elastic_transform,
elastic_video,
five_crop,
five_crop_image_pil,
five_crop_image_tensor,
five_crop_video,
hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file
horizontal_flip,
horizontal_flip_bounding_box,
horizontal_flip_image_pil,
horizontal_flip_image_tensor,
horizontal_flip_mask,
horizontal_flip_video,
pad,
pad_bounding_box,
pad_image_pil,
pad_image_tensor,
pad_mask,
pad_video,
perspective,
perspective_bounding_box,
perspective_image_pil,
perspective_image_tensor,
perspective_mask,
perspective_video,
resize,
resize_bounding_box,
resize_image_pil,
resize_image_tensor,
resize_mask,
resize_video,
resized_crop,
resized_crop_bounding_box,
resized_crop_image_pil,
resized_crop_image_tensor,
resized_crop_mask,
resized_crop_video,
rotate,
rotate_bounding_box,
rotate_image_pil,
rotate_image_tensor,
rotate_mask,
rotate_video,
ten_crop,
ten_crop_image_pil,
ten_crop_image_tensor,
ten_crop_video,
vertical_flip,
vertical_flip_bounding_box,
vertical_flip_image_pil,
vertical_flip_image_tensor,
vertical_flip_mask,
vertical_flip_video,
vflip,
)
from ._misc import (
gaussian_blur,
gaussian_blur_image_pil,
gaussian_blur_image_tensor,
gaussian_blur_video,
normalize,
normalize_image_tensor,
normalize_video,
)
from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image
from ._deprecated import get_image_size, to_grayscale, to_tensor # usort: skip
from typing import Union
import PIL.Image
import torch
from util import datapoints
from transforms.functional import pil_to_tensor, to_pil_image
from ._utils import is_simple_tensor
def erase_image_tensor(
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
if not inplace:
image = image.clone()
image[..., i : i + h, j : j + w] = v
return image
@torch.jit.unused
def erase_image_pil(
image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> PIL.Image.Image:
t_img = pil_to_tensor(image)
output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return to_pil_image(output, mode=image.mode)
def erase_video(
video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
def erase(
inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
i: int,
j: int,
h: int,
w: int,
v: torch.Tensor,
inplace: bool = False,
) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, datapoints.Image):
output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, datapoints.Video):
output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return datapoints.Video.wrap_like(inpt, output)
elif isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
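# A minimal usage sketch (hypothetical coordinates): replace a 4x4 patch whose top-left corner is at
# (i=2, j=3) with zeros.
#   out = erase(torch.rand(3, 16, 16), i=2, j=3, h=4, w=4, v=torch.zeros(3, 4, 4))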
from typing import Union
import PIL.Image
import torch
from torch.nn.functional import conv2d
from util import datapoints
from transforms import _functional_pil as _FP
from transforms._functional_tensor import _max_value
from ._meta import _num_value_bits, convert_dtype_image_tensor
from ._utils import is_simple_tensor
def _rgb_to_grayscale_image_tensor(
image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
) -> torch.Tensor:
if image.shape[-3] == 1:
return image.clone()
r, g, b = image.unbind(dim=-3)
l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
l_img = l_img.unsqueeze(dim=-3)
if preserve_dtype:
l_img = l_img.to(image.dtype)
if num_output_channels == 3:
l_img = l_img.expand(image.shape)
return l_img
def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True)
rgb_to_grayscale_image_pil = _FP.to_grayscale
def rgb_to_grayscale(
inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.rgb_to_grayscale(num_output_channels=num_output_channels)
elif isinstance(inpt, PIL.Image.Image):
return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio)
fp = image1.is_floating_point()
bound = _max_value(image1.dtype)
output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
return output if fp else output.to(image1.dtype)
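# For example (hypothetical values), `_blend(image1, image2, 0.3)` computes `0.3 * image1 + 0.7 * image2`,
# clamped to the valid range of the dtype and cast back to the input dtype for integer images.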
def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
fp = image.is_floating_point()
bound = _max_value(image.dtype)
output = image.mul(brightness_factor).clamp_(0, bound)
return output if fp else output.to(image.dtype)
adjust_brightness_image_pil = _FP.adjust_brightness
def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)
def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.adjust_brightness(brightness_factor=brightness_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
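# A minimal usage sketch (hypothetical values): the dispatcher routes plain tensors, datapoints, and PIL images
# to the matching kernel.
#   out = adjust_brightness(torch.rand(3, 8, 8), brightness_factor=1.5)
#   out = adjust_brightness(PIL.Image.new("RGB", (8, 8)), brightness_factor=1.5)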
def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
if c == 1: # Match PIL behaviour
return image
grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False)
if not image.is_floating_point():
grayscale_image = grayscale_image.floor_()
return _blend(image, grayscale_image, saturation_factor)
adjust_saturation_image_pil = _FP.adjust_saturation
def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)
def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)
):
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.adjust_saturation(saturation_factor=saturation_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
fp = image.is_floating_point()
if c == 3:
grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False)
if not fp:
grayscale_image = grayscale_image.floor_()
else:
grayscale_image = image if fp else image.to(torch.float32)
mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True)
return _blend(image, mean, contrast_factor)
adjust_contrast_image_pil = _FP.adjust_contrast
def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)
def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.adjust_contrast(contrast_factor=contrast_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
num_channels, height, width = image.shape[-3:]
if num_channels not in (1, 3):
raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")
if sharpness_factor < 0:
raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")
if image.numel() == 0 or height <= 2 or width <= 2:
return image
bound = _max_value(image.dtype)
fp = image.is_floating_point()
shape = image.shape
if image.ndim > 4:
image = image.reshape(-1, num_channels, height, width)
needs_unsquash = True
else:
needs_unsquash = False
# The following is a normalized 3x3 kernel with 1s in the edges and a 5 in the middle.
kernel_dtype = image.dtype if fp else torch.float32
a, b = 1.0 / 13.0, 5.0 / 13.0
kernel = torch.tensor([[a, a, a], [a, b, a], [a, a, a]], dtype=kernel_dtype, device=image.device)
kernel = kernel.expand(num_channels, 1, 3, 3)
# We copy and cast at the same time to avoid modifications on the original data
output = image.to(dtype=kernel_dtype, copy=True)
blurred_degenerate = conv2d(output, kernel, groups=num_channels)
if not fp:
# it is better to round before cast
blurred_degenerate = blurred_degenerate.round_()
# Create a view on the underlying output while pointing at the same data. We do this to avoid indexing twice.
view = output[..., 1:-1, 1:-1]
# We speed up blending by minimizing flops and doing in-place. The 2 blend options are mathematically equivalent:
# x+(1-r)*(y-x) = x + (1-r)*y - (1-r)*x = x*r + y*(1-r)
view.add_(blurred_degenerate.sub_(view), alpha=(1.0 - sharpness_factor))
# The actual data of output have been modified by the above. We only need to clamp and cast now.
output = output.clamp_(0, bound)
if not fp:
output = output.to(image.dtype)
if needs_unsquash:
output = output.reshape(shape)
return output
adjust_sharpness_image_pil = _FP.adjust_sharpness
def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)
def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)
):
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
r, g, _ = image.unbind(dim=-3)
# Implementation is based on
# https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/src/libImaging/Convert.c#L330
minc, maxc = torch.aminmax(image, dim=-3)
# The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
# from happening in the results, because
# + S channel has division by `maxc`, which is zero only if `maxc = minc`
# + H channel has division by `(maxc - minc)`.
#
# Instead of overwriting NaN afterwards, we just prevent it from occurring so
# we don't need to deal with it in case we save the NaN in a buffer in
# backprop, if it is ever supported, but it doesn't hurt to do so.
eqc = maxc == minc
channels_range = maxc - minc
# Since `eqc => channels_range = 0`, replacing the denominator with 1 where `eqc` holds is fine.
ones = torch.ones_like(maxc)
s = channels_range / torch.where(eqc, ones, maxc)
# Note that `eqc => maxc = minc = r = g = b`. So the following calculation
# of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
# would not matter what values `rc`, `gc`, and `bc` have here, and thus
# replacing the denominator with 1 where `eqc` holds is fine.
channels_range_divisor = torch.where(eqc, ones, channels_range).unsqueeze_(dim=-3)
rc, gc, bc = ((maxc.unsqueeze(dim=-3) - image) / channels_range_divisor).unbind(dim=-3)
mask_maxc_neq_r = maxc != r
mask_maxc_eq_g = maxc == g
hg = rc.add(2.0).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
hr = bc.sub_(gc).mul_(~mask_maxc_neq_r)
hb = gc.add_(4.0).sub_(rc).mul_(mask_maxc_neq_r.logical_and_(mask_maxc_eq_g.logical_not_()))
h = hr.add_(hg).add_(hb)
h = h.mul_(1.0 / 6.0).add_(1.0).fmod_(1.0)
return torch.stack((h, s, maxc), dim=-3)
def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
h, s, v = img.unbind(dim=-3)
h6 = h.mul(6)
i = torch.floor(h6)
f = h6.sub_(i)
i = i.to(dtype=torch.int32)
sxf = s * f
one_minus_s = 1.0 - s
q = (1.0 - sxf).mul_(v).clamp_(0.0, 1.0)
t = sxf.add_(one_minus_s).mul_(v).clamp_(0.0, 1.0)
p = one_minus_s.mul_(v).clamp_(0.0, 1.0)
i.remainder_(6)
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
a1 = torch.stack((v, q, p, p, t, v), dim=-3)
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)
return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3)
def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
if c == 1: # Match PIL behaviour
return image
if image.numel() == 0:
# exit earlier on empty images
return image
orig_dtype = image.dtype
image = convert_dtype_image_tensor(image, torch.float32)
image = _rgb_to_hsv(image)
h, s, v = image.unbind(dim=-3)
h.add_(hue_factor).remainder_(1.0)
image = torch.stack((h, s, v), dim=-3)
image_hue_adj = _hsv_to_rgb(image)
return convert_dtype_image_tensor(image_hue_adj, orig_dtype)
adjust_hue_image_pil = _FP.adjust_hue
def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
return adjust_hue_image_tensor(video, hue_factor=hue_factor)
def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.adjust_hue(hue_factor=hue_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
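# A minimal usage sketch (hypothetical value): rotate the hue of a float RGB image by a quarter of the full
# hue circle.
#   out = adjust_hue(torch.rand(3, 8, 8), hue_factor=0.25)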
def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
# The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer).
# Since the gamma is non-negative, the output remains at [0, 1] scale.
if not torch.is_floating_point(image):
output = convert_dtype_image_tensor(image, torch.float32).pow_(gamma)
else:
output = image.pow(gamma)
if gain != 1.0:
# The clamp operation is needed only if multiplication is performed. It's only when gain != 1 that the scale
# of the output can go beyond [0, 1].
output = output.mul_(gain).clamp_(0.0, 1.0)
return convert_dtype_image_tensor(output, image.dtype)
adjust_gamma_image_pil = _FP.adjust_gamma
def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)
def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.adjust_gamma(gamma=gamma, gain=gain)
elif isinstance(inpt, PIL.Image.Image):
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
if image.is_floating_point():
levels = 1 << bits
return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels)
else:
num_value_bits = _num_value_bits(image.dtype)
if bits >= num_value_bits:
return image
mask = ((1 << bits) - 1) << (num_value_bits - bits)
return image & mask
posterize_image_pil = _FP.posterize
def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
return posterize_image_tensor(video, bits=bits)
def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return posterize_image_tensor(inpt, bits=bits)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.posterize(bits=bits)
elif isinstance(inpt, PIL.Image.Image):
return posterize_image_pil(inpt, bits=bits)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
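# A minimal usage sketch (hypothetical values): keep only the 2 most significant value bits of a uint8 image,
# i.e. apply the mask 0b11000000.
#   out = posterize(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8), bits=2)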
def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
if threshold > _max_value(image.dtype):
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
return torch.where(image >= threshold, invert_image_tensor(image), image)
solarize_image_pil = _FP.solarize
def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
return solarize_image_tensor(video, threshold=threshold)
def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return solarize_image_tensor(inpt, threshold=threshold)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.solarize(threshold=threshold)
elif isinstance(inpt, PIL.Image.Image):
return solarize_image_pil(inpt, threshold=threshold)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
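# A minimal usage sketch (hypothetical threshold): invert every uint8 pixel that is at or above 128.
#   out = solarize(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8), threshold=128)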
def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
if image.numel() == 0:
# exit earlier on empty images
return image
bound = _max_value(image.dtype)
fp = image.is_floating_point()
float_image = image if fp else image.to(torch.float32)
minimum = float_image.amin(dim=(-2, -1), keepdim=True)
maximum = float_image.amax(dim=(-2, -1), keepdim=True)
eq_idxs = maximum == minimum
inv_scale = maximum.sub_(minimum).mul_(1.0 / bound)
minimum[eq_idxs] = 0.0
inv_scale[eq_idxs] = 1.0
if fp:
diff = float_image.sub(minimum)
else:
diff = float_image.sub_(minimum)
return diff.div_(inv_scale).clamp_(0, bound).to(image.dtype)
autocontrast_image_pil = _FP.autocontrast
def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
return autocontrast_image_tensor(video)
def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return autocontrast_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.autocontrast()
elif isinstance(inpt, PIL.Image.Image):
return autocontrast_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.numel() == 0:
return image
# 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
# would need to be computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
# `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely
# unfeasible for `torch.int64`.
# 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
# could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
# to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
# and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
# Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
# by far the most common, we choose it as base.
output_dtype = image.dtype
image = convert_dtype_image_tensor(image, torch.uint8)
# The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
# corresponds to adding 1 to index 127 in the histogram.
batch_shape = image.shape[:-2]
flat_image = image.flatten(start_dim=-2).to(torch.long)
hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32)
hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))
cum_hist = hist.cumsum(dim=-1)
# The simplest form of lookup-table (LUT) that also achieves histogram equalization is
# `lut = cum_hist / flat_image.shape[-1] * 255`
# However, PIL uses a more elaborate scheme:
# https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
# `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255`
# The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum
# value. Thus, the "max" in `num_non_max_pixels` does not refer to 255 as the maximum value of uint8 images, but
# rather the maximum value in the image, which might be or not be 255.
index = cum_hist.argmax(dim=-1)
num_non_max_pixels = flat_image.shape[-1] - hist.gather(dim=-1, index=index.unsqueeze_(-1))
# This is a performance optimization that saves us one multiplication later. With this, the LUT computation simplifies
# to `lut = (cum_hist + step // 2) // step` and thus saving the final multiplication by 255 while keeping the
# division count the same. PIL uses the variable name `step` for this, so we keep that for easier comparison.
step = num_non_max_pixels.div_(255, rounding_mode="floor")
# Although it looks like we could return early if we find `step == 0` like PIL does, that is unfortunately not as
# easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it doesn't,
# we have to go through the computation below anyway. Since `step == 0` is an edge case anyway, it makes no sense to
# pay the runtime cost for checking it every time.
valid_equalization = step.ne(0).unsqueeze_(-1)
# `lut[k]` is computed with `cum_hist[k-1]` with `lut[0] == (step // 2) // step == 0`. Thus, we perform the
# computation only for `lut[1:]` with `cum_hist[:-1]` and add `lut[0] == 0` afterwards.
cum_hist = cum_hist[..., :-1]
(
cum_hist.add_(step // 2)
# We need the `clamp_(min=1)` call here to avoid division by zero, which fails for integer dtypes. This has no
# effect on the returned result of this kernel, since images inside the batch with `step == 0` are returned as is
# instead of the equalized version.
.div_(step.clamp_(min=1), rounding_mode="floor")
# We need the `clamp_` call here since PIL's LUT computation scheme can produce values outside the valid value
# range of uint8 images
.clamp_(0, 255)
)
lut = cum_hist.to(torch.uint8)
lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1)
equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)
output = torch.where(valid_equalization, equalized_image, image)
return convert_dtype_image_tensor(output, output_dtype)
equalize_image_pil = _FP.equalize
def equalize_video(video: torch.Tensor) -> torch.Tensor:
return equalize_image_tensor(video)
def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return equalize_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.equalize()
elif isinstance(inpt, PIL.Image.Image):
return equalize_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.is_floating_point():
return 1.0 - image
elif image.dtype == torch.uint8:
return image.bitwise_not()
else: # signed integer dtypes
# We can't use `Tensor.bitwise_not` here, since we want to retain the leading zero bit that encodes the sign
return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1)
invert_image_pil = _FP.invert
def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image_tensor(video)
def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return invert_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.invert()
elif isinstance(inpt, PIL.Image.Image):
return invert_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
import warnings
from typing import Any, List, Union
import PIL.Image
import torch
from util import datapoints
from transforms import functional as _F
@torch.jit.unused
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = "convert_color_space(..., color_space=datapoints.ColorSpace.GRAY)"
if num_output_channels == 3:
replacement = f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB)"
warnings.warn(
f"The function `to_grayscale(...{call})` is deprecated in will be removed in a future release. "
f"Instead, please use `{replacement}`.",
)
return _F.to_grayscale(inpt, num_output_channels=num_output_channels)
@torch.jit.unused
def to_tensor(inpt: Any) -> torch.Tensor:
warnings.warn(
"The function `to_tensor(...)` is deprecated and will be removed in a future release. "
"Instead, please use `to_image_tensor(...)` followed by `convert_image_dtype(...)`."
)
return _F.to_tensor(inpt)
def get_image_size(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
warnings.warn(
"The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
)
return _F.get_image_size(inpt)
import math
import numbers
import warnings
from typing import List, Optional, Sequence, Tuple, Union
import PIL.Image
import torch
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
from util import datapoints
from transforms import _functional_pil as _FP
from transforms._functional_tensor import _pad_symmetric
from transforms.functional import (
_check_antialias,
_compute_resized_output_size as __compute_resized_output_size,
_get_perspective_coeffs,
_interpolation_modes_from_int,
InterpolationMode,
pil_modes_mapping,
pil_to_tensor,
to_pil_image,
)
from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil
from ._utils import is_simple_tensor
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
if isinstance(interpolation, int):
interpolation = _interpolation_modes_from_int(interpolation)
elif not isinstance(interpolation, InterpolationMode):
raise ValueError(
f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
f"but got {interpolation}."
)
return interpolation
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-1)
horizontal_flip_image_pil = _FP.hflip
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image_tensor(mask)
def horizontal_flip_bounding_box(
bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
shape = bounding_box.shape
bounding_box = bounding_box.clone().reshape(-1, 4)
if format == datapoints.BoundingBoxFormat.XYXY:
bounding_box[:, [2, 0]] = bounding_box[:, [0, 2]].sub_(spatial_size[1]).neg_()
elif format == datapoints.BoundingBoxFormat.XYWH:
bounding_box[:, 0].add_(bounding_box[:, 2]).sub_(spatial_size[1]).neg_()
else: # format == datapoints.BoundingBoxFormat.CXCYWH:
bounding_box[:, 0].sub_(spatial_size[1]).neg_()
return bounding_box.reshape(shape)
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image_tensor(video)
def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return horizontal_flip_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.horizontal_flip()
elif isinstance(inpt, PIL.Image.Image):
return horizontal_flip_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
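# A minimal worked example (hypothetical box): flipping the XYXY box [2, 0, 5, 4] on a 10-pixel-wide image
# maps x -> width - x and swaps the two x coordinates, yielding [5, 0, 8, 4].
#   out = horizontal_flip_bounding_box(
#       torch.tensor([[2.0, 0.0, 5.0, 4.0]]), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(6, 10)
#   )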
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-2)
vertical_flip_image_pil = _FP.vflip
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return vertical_flip_image_tensor(mask)
def vertical_flip_bounding_box(
bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
shape = bounding_box.shape
bounding_box = bounding_box.clone().reshape(-1, 4)
if format == datapoints.BoundingBoxFormat.XYXY:
bounding_box[:, [1, 3]] = bounding_box[:, [3, 1]].sub_(spatial_size[0]).neg_()
elif format == datapoints.BoundingBoxFormat.XYWH:
bounding_box[:, 1].add_(bounding_box[:, 3]).sub_(spatial_size[0]).neg_()
else: # format == datapoints.BoundingBoxFormat.CXCYWH:
bounding_box[:, 1].sub_(spatial_size[0]).neg_()
return bounding_box.reshape(shape)
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
return vertical_flip_image_tensor(video)
def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return vertical_flip_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.vertical_flip()
elif isinstance(inpt, PIL.Image.Image):
return vertical_flip_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
hflip = horizontal_flip
vflip = vertical_flip
def _compute_resized_output_size(
spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
if isinstance(size, int):
size = [size]
elif max_size is not None and len(size) != 1:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)
return __compute_resized_output_size(spatial_size, size=size, max_size=max_size)
def resize_image_tensor(
image: torch.Tensor,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
interpolation = _check_interpolation(interpolation)
antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation)
assert not isinstance(antialias, str)
antialias = False if antialias is None else antialias
align_corners: Optional[bool] = None
if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
align_corners = False
else:
# The default of antialias should be True from 0.17, so we don't warn or
# error if other interpolation modes are used. This is documented.
antialias = False
shape = image.shape
num_channels, old_height, old_width = shape[-3:]
new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
if image.numel() > 0:
image = image.reshape(-1, num_channels, old_height, old_width)
dtype = image.dtype
need_cast = dtype not in (torch.float32, torch.float64)
if need_cast:
image = image.to(dtype=torch.float32)
image = interpolate(
image,
size=[new_height, new_width],
mode=interpolation.value,
align_corners=align_corners,
antialias=antialias,
)
if need_cast:
if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
image = image.clamp_(min=0, max=255)
image = image.round_().to(dtype=dtype)
return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
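# Illustrative sketch (editorial addition, not part of the original module): resizing only changes the
# last two dimensions and preserves leading batch dimensions. `antialias=False` is passed explicitly so
# the example does not depend on how the "warn" sentinel is resolved by the imported `_check_antialias`.
def _example_resize_image_tensor() -> None:
    image = torch.zeros(2, 3, 8, 6)
    out = resize_image_tensor(image, size=[4, 3], antialias=False)
    assert out.shape == (2, 3, 4, 3)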
@torch.jit.unused
def resize_image_pil(
image: PIL.Image.Image,
size: Union[Sequence[int], int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
) -> PIL.Image.Image:
interpolation = _check_interpolation(interpolation)
size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size) # type: ignore[arg-type]
return _FP.resize(image, size, interpolation=pil_modes_mapping[interpolation])
def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = resize_image_tensor(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
if needs_squeeze:
output = output.squeeze(0)
return output
def resize_bounding_box(
bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
old_height, old_width = spatial_size
new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)
w_ratio = new_width / old_width
h_ratio = new_height / old_height
ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_box.device)
return (
bounding_box.mul(ratios).to(bounding_box.dtype),
(new_height, new_width),
)
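# Illustrative sketch (editorial addition, not part of the original module): boxes are scaled by the
# same width/height ratios as the image. An explicit `[new_h, new_w]` size is used so the example does
# not rely on the smaller-edge semantics of a single-element `size`.
def _example_resize_bounding_box() -> None:
    box = torch.tensor([[2.0, 0.0, 6.0, 4.0]])
    out_box, out_size = resize_bounding_box(box, spatial_size=(4, 8), size=[2, 4])
    # Both ratios are 0.5, so every coordinate is halved and the new spatial size is (2, 4).
    assert out_box.tolist() == [[1.0, 0.0, 3.0, 2.0]]
    assert out_size == (2, 4)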
def resize_video(
video: torch.Tensor,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
def resize(
inpt: datapoints._InputTypeJIT,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
elif isinstance(inpt, PIL.Image.Image):
if antialias is False:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _affine_parse_args(
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")
if not isinstance(translate, (list, tuple)):
raise TypeError("Argument translate should be a sequence")
if len(translate) != 2:
raise ValueError("Argument translate should be a sequence of length 2")
if scale <= 0.0:
raise ValueError("Argument scale should be positive")
if not isinstance(shear, (numbers.Number, (list, tuple))):
raise TypeError("Shear should be either a single value or a sequence of two values")
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")
if isinstance(angle, int):
angle = float(angle)
if isinstance(translate, tuple):
translate = list(translate)
if isinstance(shear, numbers.Number):
shear = [shear, 0.0]
if isinstance(shear, tuple):
shear = list(shear)
if len(shear) == 1:
shear = [shear[0], shear[0]]
if len(shear) != 2:
raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")
if center is not None:
if not isinstance(center, (list, tuple)):
raise TypeError("Argument center should be a sequence")
else:
center = [float(c) for c in center]
return angle, translate, shear, center
def _get_inverse_affine_matrix(
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
# Helper method to compute inverse matrix for affine transformation
# Pillow requires inverse affine transformation matrix:
# Affine matrix is : M = T * C * RotateScaleShear * C^-1
#
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
# RotateScaleShear is rotation with scale and shear matrix
#
# RotateScaleShear(a, s, (sx, sy)) =
# = R(a) * S(s) * SHy(sy) * SHx(sx)
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
# [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
# [ 0 , 0 , 1 ]
# where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
# SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
# [0, 1 ] [-tan(s), 1]
#
# Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
rot = math.radians(angle)
sx = math.radians(shear[0])
sy = math.radians(shear[1])
cx, cy = center
tx, ty = translate
# Cached results
cos_sy = math.cos(sy)
tan_sx = math.tan(sx)
rot_minus_sy = rot - sy
cx_plus_tx = cx + tx
cy_plus_ty = cy + ty
# Rotate Scale Shear (RSS) without scaling
a = math.cos(rot_minus_sy) / cos_sy
b = -(a * tan_sx + math.sin(rot))
c = math.sin(rot_minus_sy) / cos_sy
d = math.cos(rot) - c * tan_sx
if inverted:
# Inverted rotation matrix with scale and shear
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
# and then apply center translation: C * RSS^-1 * C^-1 * T^-1
matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
else:
matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
# Apply inverse of center translation: RSS * C^-1
# and then apply translation and center : T * C * RSS * C^-1
matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy
return matrix
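# Illustrative sketch (editorial addition, not part of the original module): two quick sanity checks of
# the formula above. With no rotation/scale/shear the matrix is the identity, and a pure translation by
# (2, 3) inverts to a translation by (-2, -3). The helper is hypothetical and never called by the library.
def _example_get_inverse_affine_matrix() -> None:
    identity = _get_inverse_affine_matrix([0.0, 0.0], 0.0, [0.0, 0.0], 1.0, [0.0, 0.0])
    assert identity == [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]
    inv_translation = _get_inverse_affine_matrix([0.0, 0.0], 0.0, [2.0, 3.0], 1.0, [0.0, 0.0])
    assert inv_translation == [1.0, 0.0, -2.0, 0.0, 1.0, -3.0]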
def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired by the PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
# Points are shifted due to affine matrix torch convention about
# the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
half_w = 0.5 * w
half_h = 0.5 * h
pts = torch.tensor(
[
[-half_w, -half_h, 1.0],
[-half_w, half_h, 1.0],
[half_w, half_h, 1.0],
[half_w, -half_h, 1.0],
]
)
theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
new_pts = torch.matmul(pts, theta.T)
min_vals, max_vals = new_pts.aminmax(dim=0)
# shift points to [0, w] and [0, h] interval to match PIL results
halfs = torch.tensor((half_w, half_h))
min_vals.add_(halfs)
max_vals.add_(halfs)
# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
tol = 1e-4
inv_tol = 1.0 / tol
cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
size = cmax.sub_(cmin)
return int(size[0]), int(size[1]) # w, h
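# Illustrative sketch (editorial addition, not part of the original module): the helper returns the
# (width, height) of the expanded output frame. For a pure 90-degree rotation, given here as the
# row-major 2x3 matrix [[0, -1, 0], [1, 0, 0]], a 4x2 image expands to a 2x4 frame.
def _example_compute_affine_output_size() -> None:
    rot90 = [0.0, -1.0, 0.0, 1.0, 0.0, 0.0]
    assert _compute_affine_output_size(rot90, w=4, h=2) == (2, 4)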
def _apply_grid_transform(
img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints._FillTypeJIT
) -> torch.Tensor:
# We are using context knowledge that grid should have float dtype
fp = img.dtype == grid.dtype
float_img = img if fp else img.to(grid.dtype)
shape = float_img.shape
if shape[0] > 1:
# Apply same grid to a batch of images
grid = grid.expand(shape[0], -1, -1, -1)
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if fill is not None:
mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
float_img = torch.cat((float_img, mask), dim=1)
float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)
# Fill with required color
if fill is not None:
float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
mask = mask.expand_as(float_img)
fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)] # type: ignore[arg-type]
fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
if mode == "nearest":
bool_mask = mask < 0.5
float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
else: # 'bilinear'
# The following is mathematically equivalent to:
# img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)
img = float_img.round_().to(img.dtype) if not fp else float_img
return img
def _assert_grid_transform_inputs(
image: torch.Tensor,
matrix: Optional[List[float]],
interpolation: str,
fill: datapoints._FillTypeJIT,
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None,
) -> None:
if matrix is not None:
if not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list")
elif len(matrix) != 6:
raise ValueError("Argument matrix should have 6 float values")
if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values")
if fill is not None:
if isinstance(fill, (tuple, list)):
length = len(fill)
num_channels = image.shape[-3]
if length > 1 and length != num_channels:
raise ValueError(
"The number of elements in 'fill' cannot broadcast to match the number of "
f"channels of the image ({length} != {num_channels})"
)
elif not isinstance(fill, (int, float)):
raise ValueError("Argument fill should be either int, float, tuple or list")
if interpolation not in supported_interpolation_modes:
raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
def _affine_grid(
theta: torch.Tensor,
w: int,
h: int,
ow: int,
oh: int,
) -> torch.Tensor:
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
# AffineGridGenerator.cpp#L18
    # Differences from AffineGridGenerator are that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by another image size, so that it covers the "extend" option as in PIL.Image.rotate
dtype = theta.dtype
device = theta.device
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
return output_grid.view(1, oh, ow, 2)
def affine_image_tensor(
image: torch.Tensor,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints._FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
interpolation = _check_interpolation(interpolation)
if image.numel() == 0:
return image
shape = image.shape
ndim = image.ndim
if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
elif ndim == 3:
image = image.unsqueeze(0)
needs_unsquash = True
else:
needs_unsquash = False
height, width = shape[-2:]
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
center_f = [0.0, 0.0]
if center is not None:
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]
translate_f = [float(t) for t in translate]
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
_assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
if needs_unsquash:
output = output.reshape(shape)
return output
@torch.jit.unused
def affine_image_pil(
image: PIL.Image.Image,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints._FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
interpolation = _check_interpolation(interpolation)
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
    # It is visually better to estimate the center without the 0.5 offset;
    # otherwise an image rotated by 90 degrees is shifted compared to the output of torch.rot90 or F_t.affine.
if center is None:
height, width = get_spatial_size_image_pil(image)
center = [width * 0.5, height * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)
def _affine_bounding_box_with_expand(
bounding_box: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
center: Optional[List[float]] = None,
expand: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
if bounding_box.numel() == 0:
return bounding_box, spatial_size
original_shape = bounding_box.shape
original_dtype = bounding_box.dtype
bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
dtype = bounding_box.dtype
device = bounding_box.device
bounding_box = (
convert_format_bounding_box(
bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
)
).reshape(-1, 4)
angle, translate, shear, center = _affine_parse_args(
angle, translate, scale, shear, InterpolationMode.NEAREST, center
)
if center is None:
height, width = spatial_size
center = [width * 0.5, height * 0.5]
affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
transposed_affine_matrix = (
torch.tensor(
affine_vector,
dtype=dtype,
device=device,
)
.reshape(2, 3)
.T
)
# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
# Single point structure is similar to
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
# 2) Now let's transform the points using affine matrix
transformed_points = torch.matmul(points, transposed_affine_matrix)
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
# and compute bounding box from 4 transformed points:
transformed_points = transformed_points.reshape(-1, 4, 2)
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
if expand:
# Compute minimum point for transformed image frame:
# Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
height, width = spatial_size
points = torch.tensor(
[
[0.0, 0.0, 1.0],
[0.0, float(height), 1.0],
[float(width), float(height), 1.0],
[float(width), 0.0, 1.0],
],
dtype=dtype,
device=device,
)
new_points = torch.matmul(points, transposed_affine_matrix)
tr = torch.amin(new_points, dim=0, keepdim=True)
# Translate bounding boxes
out_bboxes.sub_(tr.repeat((1, 2)))
# Estimate meta-data for image with inverted=True and with center=[0,0]
affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
spatial_size = (new_height, new_width)
out_bboxes = clamp_bounding_box(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
out_bboxes = convert_format_bounding_box(
out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
).reshape(original_shape)
out_bboxes = out_bboxes.to(original_dtype)
return out_bboxes, spatial_size
def affine_bounding_box(
bounding_box: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
center: Optional[List[float]] = None,
) -> torch.Tensor:
out_box, _ = _affine_bounding_box_with_expand(
bounding_box,
format=format,
spatial_size=spatial_size,
angle=angle,
translate=translate,
scale=scale,
shear=shear,
center=center,
expand=False,
)
return out_box
def affine_mask(
mask: torch.Tensor,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
fill: datapoints._FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = affine_image_tensor(
mask,
angle=angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=InterpolationMode.NEAREST,
fill=fill,
center=center,
)
if needs_squeeze:
output = output.squeeze(0)
return output
def affine_video(
video: torch.Tensor,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints._FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
return affine_image_tensor(
video,
angle=angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
def affine(
inpt: datapoints._InputTypeJIT,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints._FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
# TODO: consider deprecating integers from angle and shear on the future
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return affine_image_tensor(
inpt,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.affine(
angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
)
elif isinstance(inpt, PIL.Image.Image):
return affine_image_pil(
inpt,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def rotate_image_tensor(
image: torch.Tensor,
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
interpolation = _check_interpolation(interpolation)
shape = image.shape
num_channels, height, width = shape[-3:]
center_f = [0.0, 0.0]
if center is not None:
if expand:
# TODO: Do we actually want to warn, or just document this?
warnings.warn("The provided center argument has no effect on the result if expand is True")
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]
    # Due to the current inconsistency of the rotation angle direction between the affine and rotate
    # implementations, we need to negate the angle here.
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
if image.numel() > 0:
image = image.reshape(-1, num_channels, height, width)
_assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
new_height, new_width = output.shape[-2:]
else:
output = image
new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
return output.reshape(shape[:-3] + (num_channels, new_height, new_width))
@torch.jit.unused
def rotate_image_pil(
image: PIL.Image.Image,
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: datapoints._FillTypeJIT = None,
) -> PIL.Image.Image:
interpolation = _check_interpolation(interpolation)
if center is not None and expand:
warnings.warn("The provided center argument has no effect on the result if expand is True")
center = None
return _FP.rotate(
image, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
)
def rotate_bounding_box(
bounding_box: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
angle: float,
expand: bool = False,
center: Optional[List[float]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
if center is not None and expand:
warnings.warn("The provided center argument has no effect on the result if expand is True")
center = None
return _affine_bounding_box_with_expand(
bounding_box,
format=format,
spatial_size=spatial_size,
angle=-angle,
translate=[0.0, 0.0],
scale=1.0,
shear=[0.0, 0.0],
center=center,
expand=expand,
)
def rotate_mask(
mask: torch.Tensor,
angle: float,
expand: bool = False,
center: Optional[List[float]] = None,
fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = rotate_image_tensor(
mask,
angle=angle,
expand=expand,
interpolation=InterpolationMode.NEAREST,
fill=fill,
center=center,
)
if needs_squeeze:
output = output.squeeze(0)
return output
def rotate_video(
video: torch.Tensor,
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
def rotate(
inpt: datapoints._InputTypeJIT,
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
elif isinstance(inpt, PIL.Image.Image):
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
elif isinstance(padding, (tuple, list)):
if len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
elif len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
else:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)
else:
raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")
return [pad_left, pad_right, pad_top, pad_bottom]
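# Illustrative sketch (editorial addition, not part of the original module): the parser accepts the
# PIL-style order [left, top, right, bottom] and returns the order `torch.nn.functional.pad` expects
# for the last two dimensions, i.e. [left, right, top, bottom].
def _example_parse_pad_padding() -> None:
    assert _parse_pad_padding(2) == [2, 2, 2, 2]
    assert _parse_pad_padding([1, 2, 3, 4]) == [1, 3, 2, 4]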
def pad_image_tensor(
image: torch.Tensor,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> torch.Tensor:
    # Be aware that while `padding` has the order `[left, top, right, bottom]`, `torch_padding` uses the order
    # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use
    # `torch_pad` internally.
torch_padding = _parse_pad_padding(padding)
if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
raise ValueError(
f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
f"but got `'{padding_mode}'`."
)
if fill is None:
fill = 0
if isinstance(fill, (int, float)):
return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
elif len(fill) == 1:
return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
else:
return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
def _pad_with_scalar_fill(
image: torch.Tensor,
torch_padding: List[int],
fill: Union[int, float],
padding_mode: str,
) -> torch.Tensor:
shape = image.shape
num_channels, height, width = shape[-3:]
batch_size = 1
for s in shape[:-3]:
batch_size *= s
image = image.reshape(batch_size, num_channels, height, width)
if padding_mode == "edge":
        # Similar to the padding order, `torch_pad`'s padding modes don't have the same names as PIL's. Thus, we map
        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
        # name.
padding_mode = "replicate"
if padding_mode == "constant":
image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
elif padding_mode in ("reflect", "replicate"):
# `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
# TODO: See https://github.com/pytorch/pytorch/issues/40763
dtype = image.dtype
if not image.is_floating_point():
needs_cast = True
image = image.to(torch.float32)
else:
needs_cast = False
image = torch_pad(image, torch_padding, mode=padding_mode)
if needs_cast:
image = image.to(dtype)
else: # padding_mode == "symmetric"
image = _pad_symmetric(image, torch_padding)
new_height, new_width = image.shape[-2:]
return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
# TODO: This should be removed once torch_pad supports non-scalar padding values
def _pad_with_vector_fill(
image: torch.Tensor,
torch_padding: List[int],
fill: List[float],
padding_mode: str,
) -> torch.Tensor:
if padding_mode != "constant":
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
left, right, top, bottom = torch_padding
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)
if top > 0:
output[..., :top, :] = fill
if left > 0:
output[..., :, :left] = fill
if bottom > 0:
output[..., -bottom:, :] = fill
if right > 0:
output[..., :, -right:] = fill
return output
pad_image_pil = _FP.pad
def pad_mask(
mask: torch.Tensor,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> torch.Tensor:
if fill is None:
fill = 0
if isinstance(fill, (tuple, list)):
raise ValueError("Non-scalar fill value is not supported")
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = pad_image_tensor(mask, padding=padding, fill=fill, padding_mode=padding_mode)
if needs_squeeze:
output = output.squeeze(0)
return output
def pad_bounding_box(
bounding_box: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
padding: List[int],
padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
if padding_mode not in ["constant"]:
# TODO: add support of other padding modes
raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")
left, right, top, bottom = _parse_pad_padding(padding)
if format == datapoints.BoundingBoxFormat.XYXY:
pad = [left, top, left, top]
else:
pad = [left, top, 0, 0]
bounding_box = bounding_box + torch.tensor(pad, dtype=bounding_box.dtype, device=bounding_box.device)
height, width = spatial_size
height += top + bottom
width += left + right
spatial_size = (height, width)
return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size
def pad_video(
video: torch.Tensor,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> torch.Tensor:
return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)
def pad(
inpt: datapoints._InputTypeJIT,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
elif isinstance(inpt, PIL.Image.Image):
return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
h, w = image.shape[-2:]
right = left + width
bottom = top + height
if left < 0 or top < 0 or right > w or bottom > h:
image = image[..., max(top, 0) : bottom, max(left, 0) : right]
torch_padding = [
max(min(right, 0) - left, 0),
max(right - max(w, left), 0),
max(min(bottom, 0) - top, 0),
max(bottom - max(h, top), 0),
]
return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
return image[..., top:bottom, left:right]
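# Illustrative sketch (editorial addition, not part of the original module): a crop window that extends
# past the image is zero-padded implicitly, so the output always has the requested height and width.
def _example_crop_image_tensor() -> None:
    image = torch.ones(1, 2, 2)
    out = crop_image_tensor(image, top=-1, left=-1, height=3, width=3)
    assert out.shape == (1, 3, 3)
    # The first row and column come from the implicit zero padding; the rest is the original image.
    assert out[0, 0, 0].item() == 0.0 and out[0, 1, 1].item() == 1.0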
crop_image_pil = _FP.crop
def crop_bounding_box(
bounding_box: torch.Tensor,
format: datapoints.BoundingBoxFormat,
top: int,
left: int,
height: int,
width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
# Crop or implicit pad if left and/or top have negative values:
if format == datapoints.BoundingBoxFormat.XYXY:
sub = [left, top, left, top]
else:
sub = [left, top, 0, 0]
bounding_box = bounding_box - torch.tensor(sub, dtype=bounding_box.dtype, device=bounding_box.device)
spatial_size = (height, width)
return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = crop_image_tensor(mask, top, left, height, width)
if needs_squeeze:
output = output.squeeze(0)
return output
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
return crop_image_tensor(video, top, left, height, width)
def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return crop_image_tensor(inpt, top, left, height, width)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.crop(top, left, height, width)
elif isinstance(inpt, PIL.Image.Image):
return crop_image_pil(inpt, top, left, height, width)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
# src/libImaging/Geometry.c#L394
#
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
theta1 = torch.tensor(
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
)
theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
shape = (1, oh * ow, 3)
output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))
output_grid = output_grid1.div_(output_grid2).sub_(1.0)
return output_grid.view(1, oh, ow, 2)
def _perspective_coefficients(
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
coefficients: Optional[List[float]],
) -> List[float]:
if coefficients is not None:
if startpoints is not None and endpoints is not None:
raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.")
elif len(coefficients) != 8:
raise ValueError("Argument coefficients should have 8 float values")
return coefficients
elif startpoints is not None and endpoints is not None:
return _get_perspective_coeffs(startpoints, endpoints)
else:
raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")
def perspective_image_tensor(
image: torch.Tensor,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints._FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
interpolation = _check_interpolation(interpolation)
if image.numel() == 0:
return image
shape = image.shape
ndim = image.ndim
if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
elif ndim == 3:
image = image.unsqueeze(0)
needs_unsquash = True
else:
needs_unsquash = False
_assert_grid_transform_inputs(
image,
matrix=None,
interpolation=interpolation.value,
fill=fill,
supported_interpolation_modes=["nearest", "bilinear"],
coeffs=perspective_coeffs,
)
oh, ow = shape[-2:]
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
if needs_unsquash:
output = output.reshape(shape)
return output
@torch.jit.unused
def perspective_image_pil(
image: PIL.Image.Image,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
fill: datapoints._FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
interpolation = _check_interpolation(interpolation)
return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
def perspective_bounding_box(
bounding_box: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
if bounding_box.numel() == 0:
return bounding_box
perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
original_shape = bounding_box.shape
# TODO: first cast to float if bbox is int64 before convert_format_bounding_box
bounding_box = (
convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
).reshape(-1, 4)
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
device = bounding_box.device
# perspective_coeffs are computed as endpoint -> start point
# We have to invert perspective_coeffs for bboxes:
# (x, y) - end point and (x_out, y_out) - start point
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
# and we would like to get:
# x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
# / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
# y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
# / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
# and compute inv_coeffs in terms of coeffs
denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
if denom == 0:
raise RuntimeError(
f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
f"Denominator is zero, denom={denom}"
)
inv_coeffs = [
(perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
(-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
(perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
(-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
(perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
(-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
(-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
(-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
]
theta1 = torch.tensor(
[[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
dtype=dtype,
device=device,
)
theta2 = torch.tensor(
[[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
)
# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
# Single point structure is similar to
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
# 2) Now let's transform the points using perspective matrices
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
numer_points = torch.matmul(points, theta1.T)
denom_points = torch.matmul(points, theta2.T)
transformed_points = numer_points.div_(denom_points)
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
# and compute bounding box from 4 transformed points:
transformed_points = transformed_points.reshape(-1, 4, 2)
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
out_bboxes = clamp_bounding_box(
torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=spatial_size,
)
# out_bboxes should be of shape [N boxes, 4]
return convert_format_bounding_box(
out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
).reshape(original_shape)
def perspective_mask(
mask: torch.Tensor,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
fill: datapoints._FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = perspective_image_tensor(
mask, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients
)
if needs_squeeze:
output = output.squeeze(0)
return output
def perspective_video(
video: torch.Tensor,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints._FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
return perspective_image_tensor(
video, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
)
def perspective(
inpt: datapoints._InputTypeJIT,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints._FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return perspective_image_tensor(
inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.perspective(
startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
)
elif isinstance(inpt, PIL.Image.Image):
return perspective_image_pil(
inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def elastic_image_tensor(
image: torch.Tensor,
displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
interpolation = _check_interpolation(interpolation)
if image.numel() == 0:
return image
shape = image.shape
ndim = image.ndim
device = image.device
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
# Patch: elastic transform should support (cpu,f16) input
is_cpu_half = device.type == "cpu" and dtype == torch.float16
if is_cpu_half:
image = image.to(torch.float32)
dtype = torch.float32
    # We are aware that if the input image dtype is uint8 and the displacement is float64, then the
    # displacement will be cast to float32 and all computations will be done in float32.
    # We can fix this later if needed.
expected_shape = (1,) + shape[-2:] + (2,)
if expected_shape != displacement.shape:
raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")
if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
elif ndim == 3:
image = image.unsqueeze(0)
needs_unsquash = True
else:
needs_unsquash = False
if displacement.dtype != dtype or displacement.device != device:
displacement = displacement.to(dtype=dtype, device=device)
image_height, image_width = shape[-2:]
grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement)
output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
if needs_unsquash:
output = output.reshape(shape)
if is_cpu_half:
output = output.to(torch.float16)
return output
@torch.jit.unused
def elastic_image_pil(
image: PIL.Image.Image,
displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints._FillTypeJIT = None,
) -> PIL.Image.Image:
t_img = pil_to_tensor(image)
output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
return to_pil_image(output, mode=image.mode)
def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
sy, sx = size
base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
return base_grid
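# Illustrative sketch (editorial addition, not part of the original module): the identity grid samples
# pixel centers in the normalized [-1, 1] range used by `grid_sample`, so the top-left entry of a 2x4
# grid is (-0.75, -0.5).
def _example_create_identity_grid() -> None:
    grid = _create_identity_grid((2, 4), device=torch.device("cpu"), dtype=torch.float32)
    assert grid.shape == (1, 2, 4, 2)
    assert grid[0, 0, 0].tolist() == [-0.75, -0.5]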
def elastic_bounding_box(
bounding_box: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
displacement: torch.Tensor,
) -> torch.Tensor:
if bounding_box.numel() == 0:
return bounding_box
# TODO: add in docstring about approximation we are doing for grid inversion
device = bounding_box.device
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
if displacement.dtype != dtype or displacement.device != device:
displacement = displacement.to(dtype=dtype, device=device)
original_shape = bounding_box.shape
# TODO: first cast to float if bbox is int64 before convert_format_bounding_box
bounding_box = (
convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
).reshape(-1, 4)
id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
# This is not an exact inverse of the grid
inv_grid = id_grid.sub_(displacement)
# Get points from bboxes
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
if points.is_floating_point():
points = points.ceil_()
index_xy = points.to(dtype=torch.long)
index_x, index_y = index_xy[:, 0], index_xy[:, 1]
# Transform points:
t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)
transformed_points = transformed_points.reshape(-1, 4, 2)
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
out_bboxes = clamp_bounding_box(
torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=spatial_size,
)
return convert_format_bounding_box(
out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
).reshape(original_shape)
def elastic_mask(
mask: torch.Tensor,
displacement: torch.Tensor,
fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)
if needs_squeeze:
output = output.squeeze(0)
return output
def elastic_video(
video: torch.Tensor,
displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
def elastic(
inpt: datapoints._InputTypeJIT,
displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
if not isinstance(displacement, torch.Tensor):
raise TypeError("Argument displacement should be a Tensor")
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, PIL.Image.Image):
return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
elastic_transform = elastic
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
if isinstance(output_size, numbers.Number):
s = int(output_size)
return [s, s]
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
return [output_size[0], output_size[0]]
else:
return list(output_size)
def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
return [
(crop_width - image_width) // 2 if crop_width > image_width else 0,
(crop_height - image_height) // 2 if crop_height > image_height else 0,
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
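# Illustrative sketch (editorial addition, not part of the original module): when the requested crop is
# larger than the image, the helper returns PIL-style [left, top, right, bottom] padding, putting the
# extra pixel on the right/bottom for odd differences.
def _example_center_crop_compute_padding() -> None:
    # A 5x5 crop of a 2x2 image needs 1 pixel of padding on the left/top and 2 on the right/bottom.
    assert _center_crop_compute_padding(5, 5, 2, 2) == [1, 1, 2, 2]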
def _center_crop_compute_crop_anchor(
crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
crop_top = int(round((image_height - crop_height) / 2.0))
crop_left = int(round((image_width - crop_width) / 2.0))
return crop_top, crop_left
def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
shape = image.shape
if image.numel() == 0:
return image.reshape(shape[:-2] + (crop_height, crop_width))
image_height, image_width = shape[-2:]
if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)
image_height, image_width = image.shape[-2:]
if crop_width == image_width and crop_height == image_height:
return image
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
@torch.jit.unused
def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
image_height, image_width = get_spatial_size_image_pil(image)
if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
image = pad_image_pil(image, padding_ltrb, fill=0)
image_height, image_width = get_spatial_size_image_pil(image)
if crop_width == image_width and crop_height == image_height:
return image
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
return crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)
def center_crop_bounding_box(
bounding_box: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size)
return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = center_crop_image_tensor(image=mask, output_size=output_size)
if needs_squeeze:
output = output.squeeze(0)
return output
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
return center_crop_image_tensor(video, output_size)
def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return center_crop_image_tensor(inpt, output_size)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.center_crop(output_size)
elif isinstance(inpt, PIL.Image.Image):
return center_crop_image_pil(inpt, output_size)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def resized_crop_image_tensor(
image: torch.Tensor,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
image = crop_image_tensor(image, top, left, height, width)
return resize_image_tensor(image, size, interpolation=interpolation, antialias=antialias)
@torch.jit.unused
def resized_crop_image_pil(
image: PIL.Image.Image,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
image = crop_image_pil(image, top, left, height, width)
return resize_image_pil(image, size, interpolation=interpolation)
def resized_crop_bounding_box(
bounding_box: torch.Tensor,
format: datapoints.BoundingBoxFormat,
top: int,
left: int,
height: int,
width: int,
size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
bounding_box, _ = crop_bounding_box(bounding_box, format, top, left, height, width)
return resize_bounding_box(bounding_box, spatial_size=(height, width), size=size)
def resized_crop_mask(
mask: torch.Tensor,
top: int,
left: int,
height: int,
width: int,
size: List[int],
) -> torch.Tensor:
mask = crop_mask(mask, top, left, height, width)
return resize_mask(mask, size)
def resized_crop_video(
video: torch.Tensor,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
return resized_crop_image_tensor(
video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
)
def resized_crop(
inpt: datapoints._InputTypeJIT,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return resized_crop_image_tensor(
inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
elif isinstance(inpt, PIL.Image.Image):
return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
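# Hedged example (input shape and crop box are made up): resized_crop is a crop followed by a
# resize, so the output spatial size always equals `size`. `antialias=True` is passed
# explicitly here to avoid the "warn" default.
#
#   >>> out = resized_crop(torch.rand(3, 32, 32), top=4, left=4, height=16, width=16, size=[8, 8], antialias=True)
#   >>> tuple(out.shape)
#   (3, 8, 8)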
def _parse_five_crop_size(size: List[int]) -> List[int]:
if isinstance(size, numbers.Number):
s = int(size)
size = [s, s]
elif isinstance(size, (tuple, list)) and len(size) == 1:
s = size[0]
size = [s, s]
if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.")
return size
def five_crop_image_tensor(
image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
crop_height, crop_width = _parse_five_crop_size(size)
image_height, image_width = image.shape[-2:]
if crop_width > image_width or crop_height > image_height:
raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
tl = crop_image_tensor(image, 0, 0, crop_height, crop_width)
tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width)
bl = crop_image_tensor(image, image_height - crop_height, 0, crop_height, crop_width)
br = crop_image_tensor(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
center = center_crop_image_tensor(image, [crop_height, crop_width])
return tl, tr, bl, br, center
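# Illustrative sketch (shapes are assumptions): the four corner crops and the center crop all
# share the requested crop size.
#
#   >>> crops = five_crop_image_tensor(torch.rand(3, 32, 32), [24, 24])
#   >>> [tuple(c.shape) for c in crops]
#   [(3, 24, 24), (3, 24, 24), (3, 24, 24), (3, 24, 24), (3, 24, 24)]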
@torch.jit.unused
def five_crop_image_pil(
image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
crop_height, crop_width = _parse_five_crop_size(size)
image_height, image_width = get_spatial_size_image_pil(image)
if crop_width > image_width or crop_height > image_height:
raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
tl = crop_image_pil(image, 0, 0, crop_height, crop_width)
tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
bl = crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width)
br = crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
center = center_crop_image_pil(image, [crop_height, crop_width])
return tl, tr, bl, br, center
def five_crop_video(
video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return five_crop_image_tensor(video, size)
ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]
def five_crop(
inpt: ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return five_crop_image_tensor(inpt, size)
elif isinstance(inpt, datapoints.Image):
output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
return tuple(datapoints.Image.wrap_like(inpt, item) for item in output) # type: ignore[return-value]
elif isinstance(inpt, datapoints.Video):
output = five_crop_video(inpt.as_subclass(torch.Tensor), size)
return tuple(datapoints.Video.wrap_like(inpt, item) for item in output) # type: ignore[return-value]
elif isinstance(inpt, PIL.Image.Image):
return five_crop_image_pil(inpt, size)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def ten_crop_image_tensor(
image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
non_flipped = five_crop_image_tensor(image, size)
if vertical_flip:
image = vertical_flip_image_tensor(image)
else:
image = horizontal_flip_image_tensor(image)
flipped = five_crop_image_tensor(image, size)
return non_flipped + flipped
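# Hedged example (values assumed): the ten crops are the five crops of the original image
# followed by the five crops of the flipped image.
#
#   >>> len(ten_crop_image_tensor(torch.rand(3, 32, 32), [24, 24]))
#   10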
@torch.jit.unused
def ten_crop_image_pil(
image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
PIL.Image.Image,
PIL.Image.Image,
PIL.Image.Image,
PIL.Image.Image,
PIL.Image.Image,
PIL.Image.Image,
PIL.Image.Image,
PIL.Image.Image,
PIL.Image.Image,
PIL.Image.Image,
]:
non_flipped = five_crop_image_pil(image, size)
if vertical_flip:
image = vertical_flip_image_pil(image)
else:
image = horizontal_flip_image_pil(image)
flipped = five_crop_image_pil(image, size)
return non_flipped + flipped
def ten_crop_video(
video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
def ten_crop(
inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False
) -> Tuple[
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
]:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
elif isinstance(inpt, datapoints.Image):
output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
return tuple(datapoints.Image.wrap_like(inpt, item) for item in output) # type: ignore[return-value]
elif isinstance(inpt, datapoints.Video):
output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
return tuple(datapoints.Video.wrap_like(inpt, item) for item in output) # type: ignore[return-value]
elif isinstance(inpt, PIL.Image.Image):
return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
from typing import List, Optional, Tuple, Union
import PIL.Image
import torch
from util import datapoints
from util.datapoints import BoundingBoxFormat
from transforms import _functional_pil as _FP
from transforms._functional_tensor import _max_value
from ._utils import is_simple_tensor
def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
chw = list(image.shape[-3:])
ndims = len(chw)
if ndims == 3:
return chw
elif ndims == 2:
chw.insert(0, 1)
return chw
else:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
get_dimensions_image_pil = _FP.get_dimensions
def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return get_dimensions_image_tensor(inpt)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
channels = inpt.num_channels
height, width = inpt.spatial_size
return [channels, height, width]
elif isinstance(inpt, PIL.Image.Image):
return get_dimensions_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def get_num_channels_image_tensor(image: torch.Tensor) -> int:
chw = image.shape[-3:]
ndims = len(chw)
if ndims == 3:
return chw[0]
elif ndims == 2:
return 1
else:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
get_num_channels_image_pil = _FP.get_image_num_channels
def get_num_channels_video(video: torch.Tensor) -> int:
return get_num_channels_image_tensor(video)
def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return get_num_channels_image_tensor(inpt)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
return inpt.num_channels
elif isinstance(inpt, PIL.Image.Image):
return get_num_channels_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
# We changed the name so that it can be used not only for images but also for videos. Thus, we just alias it
# without deprecating the old name.
get_image_num_channels = get_num_channels
def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
hw = list(image.shape[-2:])
ndims = len(hw)
if ndims == 2:
return hw
else:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
@torch.jit.unused
def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
width, height = _FP.get_image_size(image)
return [height, width]
def get_spatial_size_video(video: torch.Tensor) -> List[int]:
return get_spatial_size_image_tensor(video)
def get_spatial_size_mask(mask: torch.Tensor) -> List[int]:
return get_spatial_size_image_tensor(mask)
@torch.jit.unused
def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[int]:
return list(bounding_box.spatial_size)
def get_spatial_size(inpt: datapoints._InputTypeJIT) -> List[int]:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return get_spatial_size_image_tensor(inpt)
elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)):
return list(inpt.spatial_size)
elif isinstance(inpt, PIL.Image.Image):
return get_spatial_size_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
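# Illustrative sketch (the tensor shape is an assumption): for a CHW tensor the three meta
# helpers return [C, H, W], C, and [H, W] respectively.
#
#   >>> img = torch.rand(3, 10, 20)
#   >>> get_dimensions(img), get_num_channels(img), get_spatial_size(img)
#   ([3, 10, 20], 3, [10, 20])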
def get_num_frames_video(video: torch.Tensor) -> int:
return video.shape[-4]
def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return get_num_frames_video(inpt)
elif isinstance(inpt, datapoints.Video):
return inpt.num_frames
else:
raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
xyxy = xywh if inplace else xywh.clone()
xyxy[..., 2:] += xyxy[..., :2]
return xyxy
def _xyxy_to_xywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
xywh = xyxy if inplace else xyxy.clone()
xywh[..., 2:] -= xywh[..., :2]
return xywh
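# Hedged sketch (box values are made up): XYWH adds the origin to width/height to get XYXY,
# and the inverse subtracts it again.
#
#   >>> _xywh_to_xyxy(torch.tensor([[10.0, 20.0, 30.0, 40.0]]), inplace=False)
#   tensor([[10., 20., 40., 60.]])
#   >>> _xyxy_to_xywh(torch.tensor([[10.0, 20.0, 40.0, 60.0]]), inplace=False)
#   tensor([[10., 20., 30., 40.]])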
def _cxcywh_to_xyxy(cxcywh: torch.Tensor, inplace: bool) -> torch.Tensor:
if not inplace:
cxcywh = cxcywh.clone()
# Trick to do fast division by 2 and ceil, without casting. It produces the same result as
# `torchvision.ops._box_convert._box_cxcywh_to_xyxy`.
half_wh = cxcywh[..., 2:].div(-2, rounding_mode=None if cxcywh.is_floating_point() else "floor").abs_()
# (cx - width / 2) = x1, same for y1
cxcywh[..., :2].sub_(half_wh)
# (x1 + width) = x2, same for y2
cxcywh[..., 2:].add_(cxcywh[..., :2])
return cxcywh
def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
if not inplace:
xyxy = xyxy.clone()
# (x2 - x1) = width, same for height
xyxy[..., 2:].sub_(xyxy[..., :2])
# (x1 * 2 + width) / 2 = x1 + width / 2 = x1 + (x2-x1)/2 = (x1 + x2)/2 = cx, same for cy
xyxy[..., :2].mul_(2).add_(xyxy[..., 2:]).div_(2, rounding_mode=None if xyxy.is_floating_point() else "floor")
return xyxy
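# Hedged sketch (box values are made up): a CXCYWH box round-trips through XYXY with the two
# helpers above; for floating point boxes no rounding is involved.
#
#   >>> xyxy = _cxcywh_to_xyxy(torch.tensor([[50.0, 50.0, 20.0, 10.0]]), inplace=False)  # cx, cy, w, h
#   >>> xyxy
#   tensor([[40., 45., 60., 55.]])
#   >>> _xyxy_to_cxcywh(xyxy, inplace=False)
#   tensor([[50., 50., 20., 10.]])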
def _convert_format_bounding_box(
bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
) -> torch.Tensor:
if new_format == old_format:
return bounding_box
# TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
if old_format == BoundingBoxFormat.XYWH:
bounding_box = _xywh_to_xyxy(bounding_box, inplace)
elif old_format == BoundingBoxFormat.CXCYWH:
bounding_box = _cxcywh_to_xyxy(bounding_box, inplace)
if new_format == BoundingBoxFormat.XYWH:
bounding_box = _xyxy_to_xywh(bounding_box, inplace)
elif new_format == BoundingBoxFormat.CXCYWH:
bounding_box = _xyxy_to_cxcywh(bounding_box, inplace)
return bounding_box
def convert_format_bounding_box(
inpt: datapoints._InputTypeJIT,
old_format: Optional[BoundingBoxFormat] = None,
new_format: Optional[BoundingBoxFormat] = None,
inplace: bool = False,
) -> datapoints._InputTypeJIT:
# This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
# inputs as well as extract it from `datapoints.BoundingBox` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# default error that would be thrown if `new_format` had no default value.
if new_format is None:
raise TypeError("convert_format_bounding_box() missing 1 required argument: 'new_format'")
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if old_format is None:
raise ValueError("For simple tensor inputs, `old_format` has to be passed.")
return _convert_format_bounding_box(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
elif isinstance(inpt, datapoints.BoundingBox):
if old_format is not None:
raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
output = _convert_format_bounding_box(
inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
)
return datapoints.BoundingBox.wrap_like(inpt, output, format=new_format)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
)
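# Hedged usage sketch (box values are assumptions): for a plain tensor `old_format` must be
# given explicitly, whereas for a `datapoints.BoundingBox` the format is taken from the
# datapoint and `old_format` must be omitted.
#
#   >>> boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
#   >>> convert_format_bounding_box(boxes, old_format=BoundingBoxFormat.XYWH, new_format=BoundingBoxFormat.XYXY)
#   tensor([[10., 20., 40., 60.]])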
def _clamp_bounding_box(
bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
in_dtype = bounding_box.dtype
bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
xyxy_boxes = convert_format_bounding_box(
bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
)
xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
out_boxes = convert_format_bounding_box(
xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
)
return out_boxes.to(in_dtype)
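# Illustrative sketch (numbers are made up): an XYXY box reaching outside a (height, width)
# canvas of (80, 100) is clamped to the image bounds.
#
#   >>> box = torch.tensor([[-5.0, 10.0, 120.0, 90.0]])
#   >>> _clamp_bounding_box(box, format=BoundingBoxFormat.XYXY, spatial_size=(80, 100))
#   tensor([[  0.,  10., 100.,  80.]])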
def clamp_bounding_box(
inpt: datapoints._InputTypeJIT,
format: Optional[BoundingBoxFormat] = None,
spatial_size: Optional[Tuple[int, int]] = None,
) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if format is None or spatial_size is None:
raise ValueError("For simple tensor inputs, `format` and `spatial_size` has to be passed.")
return _clamp_bounding_box(inpt, format=format, spatial_size=spatial_size)
elif isinstance(inpt, datapoints.BoundingBox):
if format is not None or spatial_size is not None:
raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
output = _clamp_bounding_box(inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size)
return datapoints.BoundingBox.wrap_like(inpt, output)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
)
def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8
elif dtype == torch.int8:
return 7
elif dtype == torch.int16:
return 15
elif dtype == torch.int32:
return 31
elif dtype == torch.int64:
return 63
else:
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
if image.dtype == dtype:
return image
float_input = image.is_floating_point()
if torch.jit.is_scripting():
# TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
float_output = torch.tensor(0, dtype=dtype).is_floating_point()
else:
float_output = dtype.is_floating_point
if float_input:
# float to float
if float_output:
return image.to(dtype)
# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")
# For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
# to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
# be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# for a detailed analysis.
# To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
# Instead, we can also multiply by the maximum value plus something close to `1`. See
# https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
eps = 1e-3
max_value = float(_max_value(dtype))
# We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
# discrete set `{0, 1}`.
return image.mul(max_value + 1.0 - eps).to(dtype)
else:
# int to float
if float_output:
return image.to(dtype).mul_(1.0 / _max_value(image.dtype))
# int to int
num_value_bits_input = _num_value_bits(image.dtype)
num_value_bits_output = _num_value_bits(dtype)
if num_value_bits_input > num_value_bits_output:
return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
else:
return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
convert_image_dtype = convert_dtype_image_tensor
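# Hedged sketch of the behaviour described above (illustrative only): for float -> uint8 the
# scale factor is max_value + 1 - eps, so only an exact 1.0 maps to 255; int -> int
# conversions with fewer value bits use a right shift instead (e.g. int16 -> uint8 shifts by 7).
#
#   >>> convert_dtype_image_tensor(torch.tensor([0.0, 0.5, 1.0]), torch.uint8)
#   tensor([  0, 127, 255], dtype=torch.uint8)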
def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
return convert_dtype_image_tensor(video, dtype)
def convert_dtype(
inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], dtype: torch.dtype = torch.float
) -> torch.Tensor:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return convert_dtype_image_tensor(inpt, dtype)
elif isinstance(inpt, datapoints.Image):
output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, datapoints.Video):
output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
return datapoints.Video.wrap_like(inpt, output)
else:
raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
)
import math
from typing import List, Optional, Union
import PIL.Image
import torch
from torch.nn.functional import conv2d, pad as torch_pad
from util import datapoints
from transforms.functional import pil_to_tensor, to_pil_image
from ._utils import is_simple_tensor
def normalize_image_tensor(
image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
) -> torch.Tensor:
if not image.is_floating_point():
raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")
if image.ndim < 3:
raise ValueError(f"Expected tensor to be a tensor image of size (..., C, H, W). Got {image.shape}.")
if isinstance(std, (tuple, list)):
divzero = not all(std)
elif isinstance(std, (int, float)):
divzero = std == 0
else:
divzero = False
if divzero:
raise ValueError("std evaluated to zero, leading to division by zero.")
dtype = image.dtype
device = image.device
mean = torch.as_tensor(mean, dtype=dtype, device=device)
std = torch.as_tensor(std, dtype=dtype, device=device)
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
if inplace:
image = image.sub_(mean)
else:
image = image.sub(mean)
return image.div_(std)
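# Illustrative usage (mean/std values are assumptions): 1-d mean and std are reshaped to
# (C, 1, 1) and broadcast over the spatial dimensions.
#
#   >>> out = normalize_image_tensor(torch.ones(3, 2, 2), mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
#   >>> float(out[0, 0, 0])
#   1.0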
def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
return normalize_image_tensor(video, mean, std, inplace=inplace)
def normalize(
inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT],
mean: List[float],
std: List[float],
inplace: bool = False,
) -> torch.Tensor:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
return inpt.normalize(mean=mean, std=std, inplace=inplace)
else:
raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma)
x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0)
return kernel1d
def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x
return kernel2d
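# Hedged sketch (kernel size and sigma are assumptions): the softmax over -x**2 already yields
# a discrete Gaussian that sums to one, so no explicit normalization step is needed.
#
#   >>> k = _get_gaussian_kernel1d(5, 1.0, torch.float32, torch.device("cpu"))
#   >>> bool(torch.allclose(k.sum(), torch.tensor(1.0)))
#   True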
def gaussian_blur_image_tensor(
image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
# TODO: consider deprecating integers from sigma in the future
if isinstance(kernel_size, int):
kernel_size = [kernel_size, kernel_size]
elif len(kernel_size) != 2:
raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
for ksize in kernel_size:
if ksize % 2 == 0 or ksize < 0:
raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")
if sigma is None:
sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
else:
if isinstance(sigma, (list, tuple)):
length = len(sigma)
if length == 1:
s = float(sigma[0])
sigma = [s, s]
elif length != 2:
raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")
elif isinstance(sigma, (int, float)):
s = float(sigma)
sigma = [s, s]
else:
raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
for s in sigma:
if s <= 0.0:
raise ValueError(f"sigma should have positive values. Got {sigma}")
if image.numel() == 0:
return image
dtype = image.dtype
shape = image.shape
ndim = image.ndim
if ndim == 3:
image = image.unsqueeze(dim=0)
elif ndim > 4:
image = image.reshape((-1,) + shape[-3:])
fp = torch.is_floating_point(image)
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device)
kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1])
output = image if fp else image.to(dtype=torch.float32)
# padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
output = torch_pad(output, padding, mode="reflect")
output = conv2d(output, kernel, groups=shape[-3])
if ndim == 3:
output = output.squeeze(dim=0)
elif ndim > 4:
output = output.reshape(shape)
if not fp:
output = output.round_().to(dtype=dtype)
return output
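# Illustrative check (input shape and kernel size are made up): reflect padding by
# kernel_size // 2 on each side keeps the spatial size unchanged after the depthwise conv2d.
#
#   >>> out = gaussian_blur_image_tensor(torch.rand(3, 16, 16), kernel_size=[5, 5])
#   >>> tuple(out.shape)
#   (3, 16, 16)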
@torch.jit.unused
def gaussian_blur_image_pil(
image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> PIL.Image.Image:
t_img = pil_to_tensor(image)
output = gaussian_blur_image_tensor(t_img, kernel_size=kernel_size, sigma=sigma)
return to_pil_image(output, mode=image.mode)
def gaussian_blur_video(
video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
return gaussian_blur_image_tensor(video, kernel_size, sigma)
def gaussian_blur(
inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, datapoints.Datapoint):
return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, PIL.Image.Image):
return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
import torch
from util import datapoints
from ._utils import is_simple_tensor
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
t_max = video.shape[-4] - 1
indices = torch.linspace(0, t_max, num_samples, device=video.device).long()
return torch.index_select(video, -4, indices)
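# Hedged example (frame count is an assumption): indices are spread evenly over [0, T - 1],
# always including the first and last frame; truncation to long picks the final index set.
#
#   >>> video = torch.rand(10, 3, 8, 8)  # (T, C, H, W)
#   >>> uniform_temporal_subsample_video(video, 5).shape[0]
#   5
#   >>> torch.linspace(0, 9, 5).long().tolist()
#   [0, 2, 4, 6, 9]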
def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return uniform_temporal_subsample_video(inpt, num_samples)
elif isinstance(inpt, datapoints.Video):
output = uniform_temporal_subsample_video(inpt.as_subclass(torch.Tensor), num_samples)
return datapoints.Video.wrap_like(inpt, output)
else:
raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
from typing import Union
import numpy as np
import PIL.Image
import torch
from util import datapoints
from transforms import functional as _F
@torch.jit.unused
def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
if isinstance(inpt, np.ndarray):
output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
elif isinstance(inpt, PIL.Image.Image):
output = pil_to_tensor(inpt)
elif isinstance(inpt, torch.Tensor):
output = inpt
else:
raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
return datapoints.Image(output)
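# Illustrative sketch (the array shape is an assumption): an HWC numpy array is permuted to
# CHW and wrapped as a `datapoints.Image`.
#
#   >>> arr = np.zeros((4, 6, 3), dtype=np.uint8)
#   >>> to_image_tensor(arr).shape
#   torch.Size([3, 4, 6])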
to_image_pil = _F.to_pil_image
pil_to_tensor = _F.pil_to_tensor
# We changed the names to align them with the new naming scheme. Still, `to_pil_image` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
to_pil_image = to_image_pil