"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "9d4571e3d0bddb7e49852909eddcf5b4145c1a91"
Unverified Commit 9effc4cd authored by vfdev, committed by GitHub

[proto] Added some transformations and fixed type hints (#6245)

* Another attempt to add transforms

* Fixed padding type hint

* Fixed fill arg for pad, rotate and affine

* Code formatting and type hints for the affine transformation

* Fixed flake8

* Updated tests to save and load transforms

* Fixed code formatting issue

* Fixed jit loading issue

* Restored fill default value to None
Updated code according to the review

* Added tests for rotation, affine and zoom transforms

* Put back commented code

* Random erase bypasses boxes and masks
Went back to if-return/elif-return/else-return

* Fixed acceptable and non-acceptable types for Cutmix/Mixup

* Updated conditions for _BaseMixupCutmix
parent e75a3337
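The recurring change across these hunks widens every fill annotation from Optional[List[float]] to Optional[Union[int, float, Sequence[int], Sequence[float]]] and funnels the value through a single normalization helper before the tensor kernels are called. A minimal sketch of that normalization, mirroring the _convert_fill_arg helper added to _geometry.py further down (the asserts are illustrative, not part of the diff):

# Illustration only, not part of the diff.
from typing import List, Optional, Sequence, Union

def _convert_fill_arg(
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]
) -> Optional[List[float]]:
    if fill is None:
        fill = 0  # None restores the default fill of 0
    # Cast Sequence -> List[float] to please mypy and torch.jit.script;
    # int -> float is safe because the kernels re-cast to the input dtype.
    if not isinstance(fill, (int, float)):
        fill = [float(v) for v in list(fill)]
    else:
        fill = [float(fill)]
    return fill

assert _convert_fill_arg(None) == [0.0]
assert _convert_fill_arg(128) == [128.0]
assert _convert_fill_arg((12, 23)) == [12.0, 23.0]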
......@@ -955,18 +955,7 @@ def test_adjust_gamma(device, dtype, config, channels):
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
"pad",
[
2,
[
3,
],
[0, 3],
(3, 3),
[4, 2, 4, 3],
],
)
@pytest.mark.parametrize("pad", [2, [3], [0, 3], (3, 3), [4, 2, 4, 3]])
@pytest.mark.parametrize(
"config",
[
......@@ -3,7 +3,11 @@ import itertools
import pytest
import torch
from common_utils import assert_equal
from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
from test_prototype_transforms_functional import (
make_images,
make_bounding_boxes,
make_one_hot_labels,
)
from torchvision.prototype import transforms, features
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
......@@ -72,6 +76,9 @@ class TestSmoke:
transforms.ConvertImageDtype(),
transforms.RandomHorizontalFlip(),
transforms.Pad(5),
transforms.RandomZoomOut(),
transforms.RandomRotation(degrees=(-45, 45)),
transforms.RandomAffine(degrees=(-45, 45)),
)
def test_common(self, transform, input):
transform(input)
......@@ -317,7 +317,7 @@ def rotate_image_tensor():
[-87, 15, 90], # angle
[True, False], # expand
[None, [12, 23]], # center
[None, [128]], # fill
[None, [128], [12.0]], # fill
):
if center is not None and expand:
# Skip warning: The provided center argument is ignored if expand is True
......@@ -128,13 +128,20 @@ class BoundingBox(_Feature):
return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
def pad(
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
padding_mode: str = "constant",
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
if padding_mode not in ["constant"]:
raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
output = _F.pad_bounding_box(self, padding, format=self.format)
# Update output image size:
......@@ -153,7 +160,7 @@ class BoundingBox(_Feature):
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
......@@ -173,7 +180,7 @@ class BoundingBox(_Feature):
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
......@@ -194,18 +201,9 @@ class BoundingBox(_Feature):
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
output = _F.perspective_bounding_box(self, self.format, perspective_coeffs)
return BoundingBox.new_like(self, output, dtype=output.dtype)
def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> BoundingBox:
raise TypeError("Erase transformation does not support bounding boxes")
def mixup(self, lam: float) -> BoundingBox:
raise TypeError("Mixup transformation does not support bounding boxes")
def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> BoundingBox:
raise TypeError("Cutmix transformation does not support bounding boxes")
......@@ -120,7 +120,10 @@ class _Feature(torch.Tensor):
return self
def pad(
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
padding_mode: str = "constant",
) -> Any:
return self
......@@ -129,7 +132,7 @@ class _Feature(torch.Tensor):
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> Any:
return self
......@@ -141,7 +144,7 @@ class _Feature(torch.Tensor):
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> Any:
return self
......@@ -150,7 +153,7 @@ class _Feature(torch.Tensor):
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Any:
return self
......@@ -186,12 +189,3 @@ class _Feature(torch.Tensor):
def invert(self) -> Any:
return self
def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Any:
return self
def mixup(self, lam: float) -> Any:
return self
def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Any:
return self
......@@ -164,10 +164,20 @@ class Image(_Feature):
return Image.new_like(self, output)
def pad(
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
padding_mode: str = "constant",
) -> Image:
from torchvision.prototype.transforms import functional as _F
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
if fill is None:
fill = 0
# PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
if isinstance(fill, (int, float)):
output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
......@@ -183,10 +193,12 @@ class Image(_Feature):
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> Image:
from torchvision.prototype.transforms import functional as _F
from torchvision.prototype.transforms.functional import _geometry as _F
fill = _F._convert_fill_arg(fill)
output = _F.rotate_image_tensor(
self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
......@@ -200,10 +212,12 @@ class Image(_Feature):
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> Image:
from torchvision.prototype.transforms import functional as _F
from torchvision.prototype.transforms.functional import _geometry as _F
fill = _F._convert_fill_arg(fill)
output = _F.affine_image_tensor(
self,
......@@ -221,9 +235,11 @@ class Image(_Feature):
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image:
from torchvision.prototype.transforms import functional as _F
from torchvision.prototype.transforms.functional import _geometry as _F
fill = _F._convert_fill_arg(fill)
output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill)
return Image.new_like(self, output)
......@@ -293,25 +309,3 @@ class Image(_Feature):
output = _F.invert_image_tensor(self)
return Image.new_like(self, output)
def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.erase_image_tensor(self, i, j, h, w, v)
return Image.new_like(self, output)
def mixup(self, lam: float) -> Image:
if self.ndim < 4:
raise ValueError("Need a batch of images")
output = self.clone()
output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
return Image.new_like(self, output)
def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Image:
if self.ndim < 4:
raise ValueError("Need a batch of images")
x1, y1, x2, y2 = box
image_rolled = self.roll(1, -4)
output = self.clone()
output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
return Image.new_like(self, output)
from __future__ import annotations
from typing import Any, Optional, Sequence, cast, Union, Tuple
from typing import Any, Optional, Sequence, cast, Union
import torch
from torchvision.prototype.utils._internal import apply_recursively
......@@ -77,14 +77,3 @@ class OneHotLabel(_Feature):
return super().new_like(
other, data, categories=categories if categories is not None else other.categories, **kwargs
)
def mixup(self, lam: float) -> OneHotLabel:
if self.ndim < 2:
raise ValueError("Need a batch of one hot labels")
output = self.clone()
output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
return OneHotLabel.new_like(self, output)
def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> OneHotLabel:
box # unused
return self.mixup(lam_adjusted)
from __future__ import annotations
from typing import Tuple, List, Optional, Union, Sequence
from typing import List, Optional, Union, Sequence
import torch
from torchvision.transforms import InterpolationMode
from ._feature import _Feature
......@@ -61,10 +60,17 @@ class SegmentationMask(_Feature):
return SegmentationMask.new_like(self, output)
def pad(
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
padding_mode: str = "constant",
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
output = _F.pad_segmentation_mask(self, padding, padding_mode=padding_mode)
return SegmentationMask.new_like(self, output)
......@@ -73,7 +79,7 @@ class SegmentationMask(_Feature):
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
......@@ -88,7 +94,7 @@ class SegmentationMask(_Feature):
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
......@@ -107,18 +113,9 @@ class SegmentationMask(_Feature):
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
output = _F.perspective_segmentation_mask(self, perspective_coeffs)
return SegmentationMask.new_like(self, output)
def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> SegmentationMask:
raise TypeError("Erase transformation does not support segmentation masks")
def mixup(self, lam: float) -> SegmentationMask:
raise TypeError("Mixup transformation does not support segmentation masks")
def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> SegmentationMask:
raise TypeError("Cutmix transformation does not support segmentation masks")
......@@ -17,6 +17,8 @@ from ._geometry import (
RandomVerticalFlip,
Pad,
RandomZoomOut,
RandomRotation,
RandomAffine,
)
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
......@@ -9,7 +9,7 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from ._transform import _RandomApplyTransform
from ._utils import query_image, get_image_dimensions, has_all
from ._utils import query_image, get_image_dimensions, has_any, has_all
class RandomErasing(_RandomApplyTransform):
......@@ -86,13 +86,14 @@ class RandomErasing(_RandomApplyTransform):
return dict(i=i, j=j, h=h, w=w, v=v)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features._Feature):
return inpt.erase(**params)
if isinstance(inpt, (features.Image, torch.Tensor)):
output = F.erase_image_tensor(inpt, **params)
if isinstance(inpt, features.Image):
return features.Image.new_like(inpt, output)
return output
elif isinstance(inpt, PIL.Image.Image):
# TODO: We should implement a fallback to tensor, like gaussian_blur etc
raise RuntimeError("Not implemented")
elif isinstance(inpt, torch.Tensor):
return F.erase_image_tensor(inpt, **params)
else:
return inpt
......@@ -107,16 +108,34 @@ class _BaseMixupCutmix(Transform):
sample = inpts if len(inpts) > 1 else inpts[0]
if not has_all(sample, features.Image, features.OneHotLabel):
raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
raise TypeError(
f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels."
)
return super().forward(sample)
def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
if inpt.ndim < 2:
raise ValueError("Need a batch of one hot labels")
output = inpt.clone()
output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
return features.OneHotLabel.new_like(inpt, output)
class RandomMixup(_BaseMixupCutmix):
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(())))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features._Feature):
return inpt.mixup(**params)
lam = params["lam"]
if isinstance(inpt, features.Image):
if inpt.ndim < 4:
raise ValueError("Need a batch of images")
output = inpt.clone()
output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
return features.Image.new_like(inpt, output)
elif isinstance(inpt, features.OneHotLabel):
return self._mixup_onehotlabel(inpt, lam)
else:
return inpt
......@@ -146,7 +165,17 @@ class RandomCutmix(_BaseMixupCutmix):
return dict(box=box, lam_adjusted=lam_adjusted)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features._Feature):
return inpt.cutmix(**params)
if isinstance(inpt, features.Image):
box = params["box"]
if inpt.ndim < 4:
raise ValueError("Need a batch of images")
x1, y1, x2, y2 = box
image_rolled = inpt.roll(1, -4)
output = inpt.clone()
output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
return features.Image.new_like(inpt, output)
elif isinstance(inpt, features.OneHotLabel):
lam_adjusted = params["lam_adjusted"]
return self._mixup_onehotlabel(inpt, lam_adjusted)
else:
return inpt
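Both RandomMixup and RandomCutmix now inline the pairing logic that previously lived on the features: roll(1, -4) shifts the batch axis by one, so every sample is combined with its predecessor. A small numeric sketch of the two blends (hypothetical tensors, assuming an (N, C, H, W) batch):

# Illustration only, not part of the diff.
import torch

lam = 0.7
batch = torch.rand(2, 3, 8, 8)
rolled = batch.roll(1, -4)  # pair sample i with sample i - 1

# Mixup: convex combination of the batch with its rolled copy.
mixed = batch * lam + rolled * (1 - lam)

# Cutmix: paste a rectangular box from the rolled copy.
x1, y1, x2, y2 = 2, 2, 6, 6
cut = batch.clone()
cut[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]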
import math
from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union, Type
from typing import Any, Dict, Tuple, Optional, Callable, List, cast, Sequence, TypeVar, Union, Type
import PIL.Image
import torch
......@@ -9,7 +9,7 @@ from torchvision.prototype.utils._internal import query_recursively
from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms.functional import pil_to_tensor, to_pil_image, InterpolationMode
from ._utils import get_image_dimensions, is_simple_tensor
from ._utils import get_image_dimensions
K = TypeVar("K")
V = TypeVar("V")
......@@ -29,7 +29,10 @@ def _put_into_sample(sample: Any, id: Tuple[Any, ...], item: Any) -> Any:
class _AutoAugmentBase(Transform):
def __init__(
self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None
self,
*,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
) -> None:
super().__init__()
self.interpolation = interpolation
......@@ -66,7 +69,7 @@ class _AutoAugmentBase(Transform):
def _parse_fill(
self, image: Union[PIL.Image.Image, torch.Tensor, features.Image], num_channels: int
) -> Optional[List[float]]:
) -> Union[int, float, Sequence[int], Sequence[float]]:
fill = self.fill
if isinstance(image, PIL.Image.Image) or fill is None:
......@@ -79,36 +82,18 @@ class _AutoAugmentBase(Transform):
return fill
def _dispatch_image_kernels(
self,
image_tensor_kernel: Callable,
image_pil_kernel: Callable,
input: Any,
*args: Any,
**kwargs: Any,
) -> Any:
if isinstance(input, features.Image):
output = image_tensor_kernel(input, *args, **kwargs)
return features.Image.new_like(input, output)
elif is_simple_tensor(input):
return image_tensor_kernel(input, *args, **kwargs)
else: # isinstance(input, PIL.Image.Image):
return image_pil_kernel(input, *args, **kwargs)
def _apply_image_transform(
self,
image: Any,
transform_id: str,
magnitude: float,
interpolation: InterpolationMode,
fill: Optional[List[float]],
fill: Union[int, float, Sequence[int], Sequence[float]],
) -> Any:
if transform_id == "Identity":
return image
elif transform_id == "ShearX":
return self._dispatch_image_kernels(
F.affine_image_tensor,
F.affine_image_pil,
return F.affine(
image,
angle=0.0,
translate=[0, 0],
......@@ -118,9 +103,7 @@ class _AutoAugmentBase(Transform):
fill=fill,
)
elif transform_id == "ShearY":
return self._dispatch_image_kernels(
F.affine_image_tensor,
F.affine_image_pil,
return F.affine(
image,
angle=0.0,
translate=[0, 0],
......@@ -130,9 +113,7 @@ class _AutoAugmentBase(Transform):
fill=fill,
)
elif transform_id == "TranslateX":
return self._dispatch_image_kernels(
F.affine_image_tensor,
F.affine_image_pil,
return F.affine(
image,
angle=0.0,
translate=[int(magnitude), 0],
......@@ -142,9 +123,7 @@ class _AutoAugmentBase(Transform):
fill=fill,
)
elif transform_id == "TranslateY":
return self._dispatch_image_kernels(
F.affine_image_tensor,
F.affine_image_pil,
return F.affine(
image,
angle=0.0,
translate=[0, int(magnitude)],
......@@ -154,46 +133,25 @@ class _AutoAugmentBase(Transform):
fill=fill,
)
elif transform_id == "Rotate":
return self._dispatch_image_kernels(F.rotate_image_tensor, F.rotate_image_pil, image, angle=magnitude)
return F.rotate(image, angle=magnitude)
elif transform_id == "Brightness":
return self._dispatch_image_kernels(
F.adjust_brightness_image_tensor,
F.adjust_brightness_image_pil,
image,
brightness_factor=1.0 + magnitude,
)
return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
elif transform_id == "Color":
return self._dispatch_image_kernels(
F.adjust_saturation_image_tensor,
F.adjust_saturation_image_pil,
image,
saturation_factor=1.0 + magnitude,
)
return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
elif transform_id == "Contrast":
return self._dispatch_image_kernels(
F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, image, contrast_factor=1.0 + magnitude
)
return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
elif transform_id == "Sharpness":
return self._dispatch_image_kernels(
F.adjust_sharpness_image_tensor,
F.adjust_sharpness_image_pil,
image,
sharpness_factor=1.0 + magnitude,
)
return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
elif transform_id == "Posterize":
return self._dispatch_image_kernels(
F.posterize_image_tensor, F.posterize_image_pil, image, bits=int(magnitude)
)
return F.posterize(image, bits=int(magnitude))
elif transform_id == "Solarize":
return self._dispatch_image_kernels(
F.solarize_image_tensor, F.solarize_image_pil, image, threshold=magnitude
)
return F.solarize(image, threshold=magnitude)
elif transform_id == "AutoContrast":
return self._dispatch_image_kernels(F.autocontrast_image_tensor, F.autocontrast_image_pil, image)
return F.autocontrast(image)
elif transform_id == "Equalize":
return self._dispatch_image_kernels(F.equalize_image_tensor, F.equalize_image_pil, image)
return F.equalize(image)
elif transform_id == "Invert":
return self._dispatch_image_kernels(F.invert_image_tensor, F.invert_image_pil, image)
return F.invert(image)
else:
raise ValueError(f"No transform available for {transform_id}")
......@@ -231,7 +189,7 @@ class AutoAugment(_AutoAugmentBase):
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
......@@ -393,7 +351,7 @@ class RandAugment(_AutoAugmentBase):
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
......@@ -453,7 +411,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
*,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
......@@ -512,7 +470,7 @@ class AugMix(_AutoAugmentBase):
alpha: float = 1.0,
all_ops: bool = True,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
import collections.abc
import functools
from typing import Any, Dict, Union, Tuple, Optional, Sequence, Callable, TypeVar
from typing import Any, Dict, Union, Tuple, Optional, Sequence, TypeVar
import PIL.Image
import torch
......@@ -53,76 +52,36 @@ class ColorJitter(Transform):
return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))
def _image_transform(
self,
input: T,
*,
kernel_tensor: Callable[..., torch.Tensor],
kernel_pil: Callable[..., PIL.Image.Image],
**kwargs: Any,
) -> T:
if isinstance(input, features.Image):
output = kernel_tensor(input, **kwargs)
return features.Image.new_like(input, output)
elif is_simple_tensor(input):
return kernel_tensor(input, **kwargs)
elif isinstance(input, PIL.Image.Image):
return kernel_pil(input, **kwargs) # type: ignore[no-any-return]
else:
raise RuntimeError
@staticmethod
def _generate_value(left: float, right: float) -> float:
return float(torch.distributions.Uniform(left, right).sample())
def _get_params(self, sample: Any) -> Dict[str, Any]:
image_transforms = []
if self.brightness is not None:
image_transforms.append(
functools.partial(
self._image_transform,
kernel_tensor=F.adjust_brightness_image_tensor,
kernel_pil=F.adjust_brightness_image_pil,
brightness_factor=float(
torch.distributions.Uniform(self.brightness[0], self.brightness[1]).sample()
),
)
)
if self.contrast is not None:
image_transforms.append(
functools.partial(
self._image_transform,
kernel_tensor=F.adjust_contrast_image_tensor,
kernel_pil=F.adjust_contrast_image_pil,
contrast_factor=float(torch.distributions.Uniform(self.contrast[0], self.contrast[1]).sample()),
)
)
if self.saturation is not None:
image_transforms.append(
functools.partial(
self._image_transform,
kernel_tensor=F.adjust_saturation_image_tensor,
kernel_pil=F.adjust_saturation_image_pil,
saturation_factor=float(
torch.distributions.Uniform(self.saturation[0], self.saturation[1]).sample()
),
)
)
if self.hue is not None:
image_transforms.append(
functools.partial(
self._image_transform,
kernel_tensor=F.adjust_hue_image_tensor,
kernel_pil=F.adjust_hue_image_pil,
hue_factor=float(torch.distributions.Uniform(self.hue[0], self.hue[1]).sample()),
)
)
return dict(image_transforms=[image_transforms[idx] for idx in torch.randperm(len(image_transforms))])
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)):
return input
for transform in params["image_transforms"]:
input = transform(input)
return input
fn_idx = torch.randperm(4)
b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1])
s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1])
h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1])
return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = inpt
brightness_factor = params["brightness_factor"]
contrast_factor = params["contrast_factor"]
saturation_factor = params["saturation_factor"]
hue_factor = params["hue_factor"]
for fn_id in params["fn_idx"]:
if fn_id == 0 and brightness_factor is not None:
output = F.adjust_brightness(output, brightness_factor=brightness_factor)
elif fn_id == 1 and contrast_factor is not None:
output = F.adjust_contrast(output, contrast_factor=contrast_factor)
elif fn_id == 2 and saturation_factor is not None:
output = F.adjust_saturation(output, saturation_factor=saturation_factor)
elif fn_id == 3 and hue_factor is not None:
output = F.adjust_hue(output, hue_factor=hue_factor)
return output
class _RandomChannelShuffle(Transform):
......@@ -131,19 +90,19 @@ class _RandomChannelShuffle(Transform):
num_channels, _, _ = get_image_dimensions(image)
return dict(permutation=torch.randperm(num_channels))
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)):
return input
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt
image = input
if isinstance(input, PIL.Image.Image):
image = inpt
if isinstance(inpt, PIL.Image.Image):
image = _F.pil_to_tensor(image)
output = image[..., params["permutation"], :, :]
if isinstance(input, features.Image):
output = features.Image.new_like(input, output, color_space=features.ColorSpace.OTHER)
elif isinstance(input, PIL.Image.Image):
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
elif isinstance(inpt, PIL.Image.Image):
output = _F.to_pil_image(output)
return output
......@@ -175,33 +134,25 @@ class RandomPhotometricDistort(Transform):
contrast_before=torch.rand(()) < 0.5,
)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["brightness"]:
input = self._brightness(input)
inpt = self._brightness(inpt)
if params["contrast1"] and params["contrast_before"]:
input = self._contrast(input)
inpt = self._contrast(inpt)
if params["saturation"]:
input = self._saturation(input)
inpt = self._saturation(inpt)
if params["saturation"]:
input = self._saturation(input)
inpt = self._saturation(inpt)
if params["contrast2"] and not params["contrast_before"]:
input = self._contrast(input)
inpt = self._contrast(inpt)
if params["channel_shuffle"]:
input = self._channel_shuffle(input)
return input
inpt = self._channel_shuffle(inpt)
return inpt
class RandomEqualize(_RandomApplyTransform):
def __init__(self, p: float = 0.5):
super().__init__(p=p)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.equalize_image_tensor(input)
return features.Image.new_like(input, output)
elif is_simple_tensor(input):
return F.equalize_image_tensor(input)
elif isinstance(input, PIL.Image.Image):
return F.equalize_image_pil(input)
else:
return input
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.equalize(inpt)
......@@ -2,14 +2,14 @@ import collections.abc
import math
import numbers
import warnings
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
from typing import Any, Dict, List, Optional, Union, Sequence, Tuple, cast
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from torchvision.transforms.functional import pil_to_tensor, InterpolationMode
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
from torchvision.transforms.transforms import _setup_size, _setup_angle, _check_sequence_input
from typing_extensions import Literal
from ._transform import _RandomApplyTransform
......@@ -17,41 +17,13 @@ from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
class RandomHorizontalFlip(_RandomApplyTransform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.horizontal_flip_image_tensor(input)
return features.Image.new_like(input, output)
elif isinstance(input, features.SegmentationMask):
output = F.horizontal_flip_segmentation_mask(input)
return features.SegmentationMask.new_like(input, output)
elif isinstance(input, features.BoundingBox):
output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
return features.BoundingBox.new_like(input, output)
elif isinstance(input, PIL.Image.Image):
return F.horizontal_flip_image_pil(input)
elif is_simple_tensor(input):
return F.horizontal_flip_image_tensor(input)
else:
return input
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.horizontal_flip(inpt)
class RandomVerticalFlip(_RandomApplyTransform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.vertical_flip_image_tensor(input)
return features.Image.new_like(input, output)
elif isinstance(input, features.SegmentationMask):
output = F.vertical_flip_segmentation_mask(input)
return features.SegmentationMask.new_like(input, output)
elif isinstance(input, features.BoundingBox):
output = F.vertical_flip_bounding_box(input, format=input.format, image_size=input.image_size)
return features.BoundingBox.new_like(input, output)
elif isinstance(input, PIL.Image.Image):
return F.vertical_flip_image_pil(input)
elif is_simple_tensor(input):
return F.vertical_flip_image_tensor(input)
else:
return input
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.vertical_flip(inpt)
class Resize(Transform):
......@@ -59,27 +31,23 @@ class Resize(Transform):
self,
size: Union[int, Sequence[int]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> None:
super().__init__()
self.size = [size] if isinstance(size, int) else list(size)
self.interpolation = interpolation
self.max_size = max_size
self.antialias = antialias
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.resize_image_tensor(input, self.size, interpolation=self.interpolation)
return features.Image.new_like(input, output)
elif isinstance(input, features.SegmentationMask):
output = F.resize_segmentation_mask(input, self.size)
return features.SegmentationMask.new_like(input, output)
elif isinstance(input, features.BoundingBox):
output = F.resize_bounding_box(input, self.size, image_size=input.image_size)
return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size)))
elif isinstance(input, PIL.Image.Image):
return F.resize_image_pil(input, self.size, interpolation=self.interpolation)
elif is_simple_tensor(input):
return F.resize_image_tensor(input, self.size, interpolation=self.interpolation)
else:
return input
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(
inpt,
self.size,
interpolation=self.interpolation,
max_size=self.max_size,
antialias=self.antialias,
)
class CenterCrop(Transform):
......@@ -87,22 +55,8 @@ class CenterCrop(Transform):
super().__init__()
self.output_size = output_size
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.center_crop_image_tensor(input, self.output_size)
return features.Image.new_like(input, output)
elif is_simple_tensor(input):
return F.center_crop_image_tensor(input, self.output_size)
elif isinstance(input, PIL.Image.Image):
return F.center_crop_image_pil(input, self.output_size)
else:
return input
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
return super().forward(sample)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.center_crop(inpt, output_size=self.output_size)
class RandomResizedCrop(Transform):
......@@ -112,6 +66,7 @@ class RandomResizedCrop(Transform):
scale: Tuple[float, float] = (0.08, 1.0),
ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
......@@ -125,20 +80,16 @@ class RandomResizedCrop(Transform):
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
self.size = size
self.scale = scale
self.ratio = ratio
self.interpolation = interpolation
self.antialias = antialias
def _get_params(self, sample: Any) -> Dict[str, Any]:
# vfdev-5: technically, this op can work on bboxes/segm masks only inputs without image in samples
# What if we have multiple images/bboxes/masks of different sizes ?
# TODO: let's support bbox or mask in samples without image
image = query_image(sample)
_, height, width = get_image_dimensions(image)
area = height * width
......@@ -177,24 +128,10 @@ class RandomResizedCrop(Transform):
return dict(top=i, left=j, height=h, width=w)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.resized_crop_image_tensor(
input, **params, size=list(self.size), interpolation=self.interpolation
)
return features.Image.new_like(input, output)
elif is_simple_tensor(input):
return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation)
elif isinstance(input, PIL.Image.Image):
return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation)
else:
return input
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
return super().forward(sample)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resized_crop(
inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
)
class MultiCropResult(list):
......@@ -283,19 +220,23 @@ class BatchMultiCrop(Transform):
return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None:
if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
class Pad(Transform):
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[float, Sequence[float]] = 0.0,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (numbers.Number, str, tuple, list)):
raise TypeError("Got inappropriate fill arg")
_check_fill_arg(fill)
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
......@@ -309,55 +250,20 @@ class Pad(Transform):
self.fill = fill
self.padding_mode = padding_mode
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image) or is_simple_tensor(input):
# PyTorch's pad supports only integers on fill. So we need to overwrite the colour
output = F.pad_image_tensor(input, params["padding"], fill=0, padding_mode="constant")
left, top, right, bottom = params["padding"]
fill = torch.tensor(params["fill"], dtype=input.dtype, device=input.device).to().view(-1, 1, 1)
if top > 0:
output[..., :top, :] = fill
if left > 0:
output[..., :, :left] = fill
if bottom > 0:
output[..., -bottom:, :] = fill
if right > 0:
output[..., :, -right:] = fill
if isinstance(input, features.Image):
output = features.Image.new_like(input, output)
return output
elif isinstance(input, PIL.Image.Image):
return F.pad_image_pil(
input,
params["padding"],
fill=tuple(int(v) if input.mode != "F" else v for v in params["fill"]),
padding_mode="constant",
)
elif isinstance(input, features.BoundingBox):
output = F.pad_bounding_box(input, params["padding"], format=input.format)
left, top, right, bottom = params["padding"]
height, width = input.image_size
height += top + bottom
width += left + right
return features.BoundingBox.new_like(input, output, image_size=(height, width))
else:
return input
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)
class RandomZoomOut(_RandomApplyTransform):
def __init__(
self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
self,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
side_range: Tuple[float, float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
super().__init__(p=p)
if fill is None:
fill = 0.0
_check_fill_arg(fill)
self.fill = fill
self.side_range = side_range
......@@ -385,6 +291,126 @@ class RandomZoomOut(_RandomApplyTransform):
return dict(padding=padding, fill=fill)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
transform = Pad(**params, padding_mode="constant")
return transform(input)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.pad(inpt, **params)
class RandomRotation(Transform):
def __init__(
self,
degrees: Union[numbers.Number, Sequence],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
self.interpolation = interpolation
self.expand = expand
_check_fill_arg(fill)
self.fill = fill
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
def _get_params(self, sample: Any) -> Dict[str, Any]:
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.rotate(
inpt,
**params,
interpolation=self.interpolation,
expand=self.expand,
fill=self.fill,
center=self.center,
)
class RandomAffine(Transform):
def __init__(
self,
degrees: Union[numbers.Number, Sequence],
translate: Optional[Sequence[float]] = None,
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[float, Sequence[float]]] = None,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
if translate is not None:
_check_sequence_input(translate, "translate", req_sizes=(2,))
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
_check_sequence_input(scale, "scale", req_sizes=(2,))
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
else:
self.shear = shear
self.interpolation = interpolation
_check_fill_arg(fill)
self.fill = fill
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
def _get_params(self, sample: Any) -> Dict[str, Any]:
# Get image size
# TODO: make it work with bboxes and segm masks
image = query_image(sample)
_, height, width = get_image_dimensions(image)
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
if self.translate is not None:
max_dx = float(self.translate[0] * width)
max_dy = float(self.translate[1] * height)
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
translations = (tx, ty)
else:
translations = (0, 0)
if self.scale is not None:
scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item())
else:
scale = 1.0
shear_x = shear_y = 0.0
if self.shear is not None:
shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item())
if len(self.shear) == 4:
shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item())
shear = (shear_x, shear_y)
return dict(angle=angle, translations=translations, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.affine(
inpt,
**params,
interpolation=self.interpolation,
fill=self.fill,
center=self.center,
)
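With RandomZoomOut, RandomRotation and RandomAffine exported from _geometry (see the import hunk above), the new transforms compose like the existing prototype ones. A hedged usage sketch, assuming the prototype API exercised by TestSmoke.test_common and that Compose is available as in the stable namespace:

# Illustration only, not part of the diff.
import torch
from torchvision.prototype import features, transforms

image = features.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
pipeline = transforms.Compose(
    [
        transforms.RandomZoomOut(),
        transforms.RandomRotation(degrees=(-45, 45)),
        transforms.RandomAffine(degrees=(-45, 45), fill=128),
    ]
)
output = pipeline(image)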
......@@ -16,9 +16,10 @@ adjust_brightness_image_pil = _FP.adjust_brightness
def adjust_brightness(inpt: DType, brightness_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_brightness(brightness_factor=brightness_factor)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
else:
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
adjust_saturation_image_tensor = _FT.adjust_saturation
......@@ -28,9 +29,10 @@ adjust_saturation_image_pil = _FP.adjust_saturation
def adjust_saturation(inpt: DType, saturation_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_saturation(saturation_factor=saturation_factor)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
else:
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
adjust_contrast_image_tensor = _FT.adjust_contrast
......@@ -40,9 +42,10 @@ adjust_contrast_image_pil = _FP.adjust_contrast
def adjust_contrast(inpt: DType, contrast_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_contrast(contrast_factor=contrast_factor)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
else:
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
adjust_sharpness_image_tensor = _FT.adjust_sharpness
......@@ -52,9 +55,10 @@ adjust_sharpness_image_pil = _FP.adjust_sharpness
def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
else:
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
adjust_hue_image_tensor = _FT.adjust_hue
......@@ -64,9 +68,10 @@ adjust_hue_image_pil = _FP.adjust_hue
def adjust_hue(inpt: DType, hue_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_hue(hue_factor=hue_factor)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
else:
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
adjust_gamma_image_tensor = _FT.adjust_gamma
......@@ -76,9 +81,10 @@ adjust_gamma_image_pil = _FP.adjust_gamma
def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_gamma(gamma=gamma, gain=gain)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
else:
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
posterize_image_tensor = _FT.posterize
......@@ -88,9 +94,10 @@ posterize_image_pil = _FP.posterize
def posterize(inpt: DType, bits: int) -> DType:
if isinstance(inpt, features._Feature):
return inpt.posterize(bits=bits)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return posterize_image_pil(inpt, bits=bits)
return posterize_image_tensor(inpt, bits=bits)
else:
return posterize_image_tensor(inpt, bits=bits)
solarize_image_tensor = _FT.solarize
......@@ -100,9 +107,10 @@ solarize_image_pil = _FP.solarize
def solarize(inpt: DType, threshold: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.solarize(threshold=threshold)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return solarize_image_pil(inpt, threshold=threshold)
return solarize_image_tensor(inpt, threshold=threshold)
else:
return solarize_image_tensor(inpt, threshold=threshold)
autocontrast_image_tensor = _FT.autocontrast
......@@ -112,9 +120,10 @@ autocontrast_image_pil = _FP.autocontrast
def autocontrast(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.autocontrast()
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return autocontrast_image_pil(inpt)
return autocontrast_image_tensor(inpt)
else:
return autocontrast_image_tensor(inpt)
equalize_image_tensor = _FT.equalize
......@@ -124,9 +133,10 @@ equalize_image_pil = _FP.equalize
def equalize(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.equalize()
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return equalize_image_pil(inpt)
return equalize_image_tensor(inpt)
else:
return equalize_image_tensor(inpt)
invert_image_tensor = _FT.invert
......@@ -136,6 +146,7 @@ invert_image_pil = _FP.invert
def invert(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.invert()
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return invert_image_pil(inpt)
return invert_image_tensor(inpt)
else:
return invert_image_tensor(inpt)
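Every color functional above now follows the same explicit three-way shape restored by this PR: _Feature subclasses dispatch to their own method, PIL images go to the *_image_pil kernel, and everything else is treated as a plain tensor. A self-contained sketch in which the my_op_* kernels are hypothetical stand-ins:

# Illustration only, not part of the diff; the my_op_* kernels are stand-ins.
import PIL.Image
import torch
from torchvision.prototype import features

def my_op_image_tensor(img: torch.Tensor, factor: float) -> torch.Tensor:
    return img * factor

def my_op_image_pil(img: PIL.Image.Image, factor: float) -> PIL.Image.Image:
    return img

def my_op(inpt, factor: float):
    if isinstance(inpt, features._Feature):
        return inpt.my_op(factor=factor)  # the feature re-wraps its own output
    elif isinstance(inpt, PIL.Image.Image):
        return my_op_image_pil(inpt, factor=factor)
    else:
        return my_op_image_tensor(inpt, factor=factor)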
......@@ -47,9 +47,10 @@ def horizontal_flip_bounding_box(
def horizontal_flip(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.horizontal_flip()
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return horizontal_flip_image_pil(inpt)
return horizontal_flip_image_tensor(inpt)
else:
return horizontal_flip_image_tensor(inpt)
vertical_flip_image_tensor = _FT.vflip
......@@ -79,9 +80,10 @@ def vertical_flip_bounding_box(
def vertical_flip(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.vertical_flip()
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return vertical_flip_image_pil(inpt)
return vertical_flip_image_tensor(inpt)
else:
return vertical_flip_image_tensor(inpt)
def resize_image_tensor(
......@@ -141,13 +143,13 @@ def resize(
if isinstance(inpt, features._Feature):
antialias = False if antialias is None else antialias
return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
if antialias is not None and not antialias:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
antialias = False if antialias is None else antialias
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
else:
antialias = False if antialias is None else antialias
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
def _affine_parse_args(
......@@ -210,18 +212,22 @@ def affine_image_tensor(
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
num_channels, height, width = img.shape[-3:]
extra_dims = img.shape[:-3]
img = img.view(-1, num_channels, height, width)
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
center_f = [0.0, 0.0]
if center is not None:
_, height, width = get_dimensions_image_tensor(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
return _FT.affine(img, matrix, interpolation=interpolation.value, fill=fill)
output = _FT.affine(img, matrix, interpolation=interpolation.value, fill=fill)
return output.view(extra_dims + (num_channels, height, width))
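affine_image_tensor (and rotate_image_tensor below) gain support for arbitrary leading batch dimensions by flattening to a 4D batch before calling the low-level kernel and restoring extra_dims afterwards. A standalone sketch of the pattern, with flip standing in for the real 4D-only kernel:

# Illustration only, not part of the diff.
import torch

def batched_kernel(img: torch.Tensor) -> torch.Tensor:
    # Accepts (C, H, W), (N, C, H, W), (N1, N2, C, H, W), ...
    num_channels, height, width = img.shape[-3:]
    extra_dims = img.shape[:-3]
    flat = img.reshape(-1, num_channels, height, width)
    output = flat.flip(-1)  # stand-in for a kernel that only handles 4D input
    new_height, new_width = output.shape[-2:]
    return output.reshape(extra_dims + (num_channels, new_height, new_width))

assert batched_kernel(torch.rand(2, 5, 3, 8, 8)).shape == (2, 5, 3, 8, 8)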
def affine_image_pil(
......@@ -231,7 +237,7 @@ def affine_image_pil(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
......@@ -344,7 +350,7 @@ def affine_bounding_box(
def affine_segmentation_mask(
img: torch.Tensor,
mask: torch.Tensor,
angle: float,
translate: List[float],
scale: float,
......@@ -352,7 +358,7 @@ def affine_segmentation_mask(
center: Optional[List[float]] = None,
) -> torch.Tensor:
return affine_image_tensor(
img,
mask,
angle=angle,
translate=translate,
scale=scale,
......@@ -362,6 +368,19 @@ def affine_segmentation_mask(
)
def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> Optional[List[float]]:
if fill is None:
fill = 0
# This cast does Sequence -> List[float] to please mypy and torch.jit.script
if not isinstance(fill, (int, float)):
fill = [float(v) for v in list(fill)]
else:
# It is OK to cast int to float as later we use inpt.dtype
fill = [float(fill)]
return fill
def affine(
inpt: DType,
angle: float,
......@@ -369,14 +388,14 @@ def affine(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> DType:
if isinstance(inpt, features._Feature):
return inpt.affine(
angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return affine_image_pil(
inpt,
angle,
......@@ -387,16 +406,19 @@ def affine(
fill=fill,
center=center,
)
return affine_image_tensor(
inpt,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
else:
fill = _convert_fill_arg(fill)
return affine_image_tensor(
inpt,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
def rotate_image_tensor(
......@@ -407,6 +429,10 @@ def rotate_image_tensor(
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
num_channels, height, width = img.shape[-3:]
extra_dims = img.shape[:-3]
img = img.view(-1, num_channels, height, width)
center_f = [0.0, 0.0]
if center is not None:
if expand:
......@@ -419,7 +445,9 @@ def rotate_image_tensor(
# due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle.
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
return _FT.rotate(img, matrix, interpolation=interpolation.value, expand=expand, fill=fill)
output = _FT.rotate(img, matrix, interpolation=interpolation.value, expand=expand, fill=fill)
new_height, new_width = output.shape[-2:]
return output.view(extra_dims + (num_channels, new_height, new_width))
def rotate_image_pil(
......@@ -427,7 +455,7 @@ def rotate_image_pil(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
if center is not None and expand:
......@@ -483,37 +511,40 @@ def rotate(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> DType:
if isinstance(inpt, features._Feature):
return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
else:
fill = _convert_fill_arg(fill)
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
pad_image_pil = _FP.pad
def pad_image_tensor(
img: torch.Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant"
img: torch.Tensor, padding: Union[int, List[int]], fill: Union[int, float] = 0, padding_mode: str = "constant"
) -> torch.Tensor:
num_masks, height, width = img.shape[-3:]
num_channels, height, width = img.shape[-3:]
extra_dims = img.shape[:-3]
padded_image = _FT.pad(
img=img.view(-1, num_masks, height, width), padding=padding, fill=fill, padding_mode=padding_mode
img=img.view(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
)
new_height, new_width = padded_image.shape[-2:]
return padded_image.view(extra_dims + (num_masks, new_height, new_width))
return padded_image.view(extra_dims + (num_channels, new_height, new_width))
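With the num_masks -> num_channels rename, the kernel reads as intended: flatten leading dims, pad, restore. An illustrative shape check (padding order is left, top, right, bottom):

img = torch.rand(2, 5, 3, 32, 32)
out = pad_image_tensor(img, [4, 2, 4, 3])
assert out.shape == (2, 5, 3, 37, 40)  # H + 2 + 3, W + 4 + 4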
# TODO: This should be removed once pytorch pad supports non-scalar fill values
def _pad_with_vector_fill(
img: torch.Tensor,
padding: List[int],
padding: Union[int, List[int]],
fill: Sequence[float] = [0.0],
padding_mode: str = "constant",
) -> torch.Tensor:
......@@ -521,7 +552,7 @@ def _pad_with_vector_fill(
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
output = pad_image_tensor(img, padding, fill=0, padding_mode="constant")
left, top, right, bottom = padding
left, top, right, bottom = _FT._parse_pad_padding(padding)
fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1)
if top > 0:
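The diff truncates the body here; a hedged sketch of the remaining border overwrite, assuming an output layout of (..., C, H, W) and the (left, top, right, bottom) values parsed above:

if top > 0:
    output[..., :top, :] = fill
if left > 0:
    output[..., :, :left] = fill
if bottom > 0:
    output[..., -bottom:, :] = fill
if right > 0:
    output[..., :, -right:] = fill
return output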
......@@ -536,7 +567,7 @@ def _pad_with_vector_fill(
def pad_segmentation_mask(
segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
segmentation_mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant"
) -> torch.Tensor:
num_masks, height, width = segmentation_mask.shape[-3:]
extra_dims = segmentation_mask.shape[:-3]
......@@ -550,7 +581,7 @@ def pad_segmentation_mask(
def pad_bounding_box(
bounding_box: torch.Tensor, padding: List[int], format: features.BoundingBoxFormat
bounding_box: torch.Tensor, padding: Union[int, List[int]], format: features.BoundingBoxFormat
) -> torch.Tensor:
left, _, top, _ = _FT._parse_pad_padding(padding)
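Only the left and top padding move the image origin, so boxes shift by (left, top) and their size is unchanged. An illustrative XYXY example:

box = torch.tensor([[10, 20, 50, 60]])
out = pad_bounding_box(box, padding=[4, 2, 4, 3], format=features.BoundingBoxFormat.XYXY)
# -> [[14, 22, 54, 62]]  (x += 4, y += 2; right/bottom padding has no effect)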
......@@ -566,18 +597,27 @@ def pad_bounding_box(
def pad(
inpt: DType, padding: List[int], fill: Union[int, float, Sequence[float]] = 0.0, padding_mode: str = "constant"
inpt: DType,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
padding_mode: str = "constant",
) -> DType:
if isinstance(inpt, features._Feature):
return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
else:
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
# TODO: PyTorch's pad supports only scalar fill values, so we need to overwrite the colour
if isinstance(fill, (int, float)):
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
if fill is None:
fill = 0
return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode)
# TODO: PyTorch's pad supports only scalar fill values, so we need to overwrite the colour
if isinstance(fill, (int, float)):
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode)
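A quick sketch of how the new tensor branch routes fill values (calls are illustrative):

img = torch.zeros(3, 8, 8)
pad(img, 2)                    # fill=None is coerced to 0 -> pad_image_tensor
pad(img, 2, fill=1)            # scalar fill -> pad_image_tensor
pad(img, 2, fill=[255, 0, 0])  # per-channel fill -> _pad_with_vector_fill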
crop_image_tensor = _FT.crop
......@@ -610,9 +650,10 @@ def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int,
def crop(inpt: DType, top: int, left: int, height: int, width: int) -> DType:
if isinstance(inpt, features._Feature):
return inpt.crop(top, left, height, width)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return crop_image_pil(inpt, top, left, height, width)
return crop_image_tensor(inpt, top, left, height, width)
else:
return crop_image_tensor(inpt, top, left, height, width)
def perspective_image_tensor(
......@@ -628,7 +669,7 @@ def perspective_image_pil(
img: PIL.Image.Image,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BICUBIC,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> PIL.Image.Image:
return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
......@@ -726,21 +767,25 @@ def perspective(
inpt: DType,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> DType:
if isinstance(inpt, features._Feature):
return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return perspective_image_pil(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
else:
fill = _convert_fill_arg(fill)
return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
if isinstance(output_size, numbers.Number):
return [int(output_size), int(output_size)]
if isinstance(output_size, (tuple, list)) and len(output_size) == 1:
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
return [output_size[0], output_size[0]]
return list(output_size)
else:
return list(output_size)
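The if/elif/else restructuring keeps the parsing identical; for reference:

_center_crop_parse_output_size(224)         # -> [224, 224]
_center_crop_parse_output_size([224])       # -> [224, 224]
_center_crop_parse_output_size((224, 320))  # -> [224, 320]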
def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
......@@ -810,9 +855,10 @@ def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size:
def center_crop(inpt: DType, output_size: List[int]) -> DType:
if isinstance(inpt, features._Feature):
return inpt.center_crop(output_size)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return center_crop_image_pil(inpt, output_size)
return center_crop_image_tensor(inpt, output_size)
else:
return center_crop_image_tensor(inpt, output_size)
def resized_crop_image_tensor(
......@@ -880,12 +926,13 @@ def resized_crop(
if isinstance(inpt, features._Feature):
antialias = False if antialias is None else antialias
return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
if isinstance(inpt, PIL.Image.Image):
elif isinstance(inpt, PIL.Image.Image):
return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
antialias = False if antialias is None else antialias
return resized_crop_image_tensor(
inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
)
else:
antialias = False if antialias is None else antialias
return resized_crop_image_tensor(
inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
)
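Note that antialias=None is coerced to False in both the feature and tensor branches, while the PIL branch ignores it (PIL resampling is effectively always antialiased). An illustrative call:

out = resized_crop(torch.rand(3, 64, 64), top=8, left=8, height=48, width=48, size=[32, 32])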
def _parse_five_crop_size(size: List[int]) -> List[int]:
......
import numbers
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
......@@ -286,7 +286,7 @@ def affine(
img: Image.Image,
matrix: List[float],
interpolation: int = _pil_constants.NEAREST,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
......@@ -304,7 +304,7 @@ def rotate(
interpolation: int = _pil_constants.NEAREST,
expand: bool = False,
center: Optional[Tuple[int, int]] = None,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
......@@ -319,7 +319,7 @@ def perspective(
img: Image.Image,
perspective_coeffs: List[float],
interpolation: int = _pil_constants.BICUBIC,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
......
......@@ -350,7 +350,7 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
raise RuntimeError("Symmetric padding of N-D tensors is not supported yet")
def _parse_pad_padding(padding: List[int]) -> List[int]:
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
if isinstance(padding, int):
if torch.jit.is_scripting():
# This may be unreachable
......@@ -370,7 +370,9 @@ def _parse_pad_padding(padding: List[int]) -> List[int]:
return [pad_left, pad_right, pad_top, pad_bottom]
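The widened signature now also accepts a bare int; every accepted form expands to (left, right, top, bottom):

_parse_pad_padding(2)             # -> [2, 2, 2, 2]
_parse_pad_padding([1])           # -> [1, 1, 1, 1]
_parse_pad_padding([1, 2])        # -> [1, 1, 2, 2]  (left/right, top/bottom)
_parse_pad_padding([1, 2, 3, 4])  # -> [1, 3, 2, 4]  (input order: left, top, right, bottom)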
def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor:
def pad(
img: Tensor, padding: Union[int, List[int]], fill: Union[int, float] = 0, padding_mode: str = "constant"
) -> Tensor:
_assert_image_tensor(img)
if not isinstance(padding, (int, tuple, list)):
......@@ -383,8 +385,13 @@ def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mo
if isinstance(padding, tuple):
padding = list(padding)
if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
if isinstance(padding, list):
# TODO: Jit is failing on loading this op when scripted and saved
# https://github.com/pytorch/pytorch/issues/81100
if len(padding) not in [1, 2, 4]:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
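The restructured check still rejects invalid lengths after tuples are converted to lists; an illustrative failure:

pad(torch.rand(3, 8, 8), [1, 2, 3])
# ValueError: Padding must be an int or a 1, 2, or 4 element tuple, not a 3 element tuple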
......