Commit 9effc4cd authored by vfdev, committed by GitHub

[proto] Added some transformations and fixed type hints (#6245)

* Another attempt to add transforms

* Fixed padding type hint

* Fixed fill arg for pad, rotate, and affine

* Code formatting and type hints for the affine transformation

* Fixed flake8

* Updated tests to save and load transforms

* Fixed code formatting issue

* Fixed jit loading issue

* Restored fill default value to None
Updated code according to the review

* Added tests for rotation, affine and zoom transforms

* Put back commented code

* Random erase now bypasses boxes and masks
Went back to if-return/elif-return/else-return

* Fixed acceptable and non-acceptable types for Cutmix/Mixup

* Updated conditions for _BaseMixupCutmix
parent e75a3337
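For context, here is a minimal smoke-test-style use of the transforms this PR adds. This is an illustrative sketch against the prototype API at this revision, not code from the diff; the input construction is hypothetical:

import torch
from torchvision.prototype import features, transforms

# A single RGB image wrapped in the prototype Image feature (illustrative data).
image = features.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))

# The three transforms added here; each dispatches on the input type.
for transform in (
    transforms.RandomZoomOut(),
    transforms.RandomRotation(degrees=(-45, 45)),
    transforms.RandomAffine(degrees=(-45, 45)),
):
    image = transform(image)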
@@ -955,18 +955,7 @@ def test_adjust_gamma(device, dtype, config, channels):
 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
-@pytest.mark.parametrize(
-    "pad",
-    [
-        2,
-        [
-            3,
-        ],
-        [0, 3],
-        (3, 3),
-        [4, 2, 4, 3],
-    ],
-)
+@pytest.mark.parametrize("pad", [2, [3], [0, 3], (3, 3), [4, 2, 4, 3]])
 @pytest.mark.parametrize(
     "config",
     [
...
@@ -3,7 +3,11 @@ import itertools
 import pytest
 import torch
 from common_utils import assert_equal
-from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
+from test_prototype_transforms_functional import (
+    make_images,
+    make_bounding_boxes,
+    make_one_hot_labels,
+)
 from torchvision.prototype import transforms, features
 from torchvision.transforms.functional import to_pil_image, pil_to_tensor
@@ -72,6 +76,9 @@ class TestSmoke:
         transforms.ConvertImageDtype(),
         transforms.RandomHorizontalFlip(),
         transforms.Pad(5),
+        transforms.RandomZoomOut(),
+        transforms.RandomRotation(degrees=(-45, 45)),
+        transforms.RandomAffine(degrees=(-45, 45)),
     )
     def test_common(self, transform, input):
         transform(input)
...
@@ -317,7 +317,7 @@ def rotate_image_tensor():
         [-87, 15, 90],  # angle
         [True, False],  # expand
         [None, [12, 23]],  # center
-        [None, [128]],  # fill
+        [None, [128], [12.0]],  # fill
     ):
         if center is not None and expand:
             # Skip warning: The provided center argument is ignored if expand is True
...
@@ -128,13 +128,20 @@ class BoundingBox(_Feature):
         return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)

     def pad(
-        self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
+        self,
+        padding: Union[int, Sequence[int]],
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+        padding_mode: str = "constant",
     ) -> BoundingBox:
         from torchvision.prototype.transforms import functional as _F

         if padding_mode not in ["constant"]:
             raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

+        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
+        if not isinstance(padding, int):
+            padding = list(padding)
+
         output = _F.pad_bounding_box(self, padding, format=self.format)

         # Update output image size:
@@ -153,7 +160,7 @@ class BoundingBox(_Feature):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> BoundingBox:
         from torchvision.prototype.transforms import functional as _F
@@ -173,7 +180,7 @@ class BoundingBox(_Feature):
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> BoundingBox:
         from torchvision.prototype.transforms import functional as _F
@@ -194,18 +201,9 @@ class BoundingBox(_Feature):
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
     ) -> BoundingBox:
         from torchvision.prototype.transforms import functional as _F

         output = _F.perspective_bounding_box(self, self.format, perspective_coeffs)
         return BoundingBox.new_like(self, output, dtype=output.dtype)
-
-    def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> BoundingBox:
-        raise TypeError("Erase transformation does not support bounding boxes")
-
-    def mixup(self, lam: float) -> BoundingBox:
-        raise TypeError("Mixup transformation does not support bounding boxes")
-
-    def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> BoundingBox:
-        raise TypeError("Cutmix transformation does not support bounding boxes")
@@ -120,7 +120,10 @@ class _Feature(torch.Tensor):
         return self

     def pad(
-        self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
+        self,
+        padding: Union[int, Sequence[int]],
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+        padding_mode: str = "constant",
     ) -> Any:
         return self
@@ -129,7 +132,7 @@ class _Feature(torch.Tensor):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> Any:
         return self
@@ -141,7 +144,7 @@ class _Feature(torch.Tensor):
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> Any:
         return self
@@ -150,7 +153,7 @@ class _Feature(torch.Tensor):
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
     ) -> Any:
         return self
@@ -186,12 +189,3 @@ class _Feature(torch.Tensor):
     def invert(self) -> Any:
         return self
-
-    def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Any:
-        return self
-
-    def mixup(self, lam: float) -> Any:
-        return self
-
-    def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Any:
-        return self
@@ -164,10 +164,20 @@ class Image(_Feature):
         return Image.new_like(self, output)

     def pad(
-        self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
+        self,
+        padding: Union[int, Sequence[int]],
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+        padding_mode: str = "constant",
     ) -> Image:
         from torchvision.prototype.transforms import functional as _F

+        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
+        if not isinstance(padding, int):
+            padding = list(padding)
+
+        if fill is None:
+            fill = 0
+
         # PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
         if isinstance(fill, (int, float)):
             output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
@@ -183,10 +193,12 @@ class Image(_Feature):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> Image:
-        from torchvision.prototype.transforms import functional as _F
+        from torchvision.prototype.transforms.functional import _geometry as _F

+        fill = _F._convert_fill_arg(fill)
         output = _F.rotate_image_tensor(
             self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
@@ -200,10 +212,12 @@ class Image(_Feature):
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> Image:
-        from torchvision.prototype.transforms import functional as _F
+        from torchvision.prototype.transforms.functional import _geometry as _F

+        fill = _F._convert_fill_arg(fill)
         output = _F.affine_image_tensor(
             self,
@@ -221,9 +235,11 @@ class Image(_Feature):
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
     ) -> Image:
-        from torchvision.prototype.transforms import functional as _F
+        from torchvision.prototype.transforms.functional import _geometry as _F

+        fill = _F._convert_fill_arg(fill)
         output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill)
         return Image.new_like(self, output)
@@ -293,25 +309,3 @@ class Image(_Feature):
         output = _F.invert_image_tensor(self)
         return Image.new_like(self, output)
-
-    def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Image:
-        from torchvision.prototype.transforms import functional as _F
-
-        output = _F.erase_image_tensor(self, i, j, h, w, v)
-        return Image.new_like(self, output)
-
-    def mixup(self, lam: float) -> Image:
-        if self.ndim < 4:
-            raise ValueError("Need a batch of images")
-        output = self.clone()
-        output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
-        return Image.new_like(self, output)
-
-    def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Image:
-        if self.ndim < 4:
-            raise ValueError("Need a batch of images")
-        x1, y1, x2, y2 = box
-        image_rolled = self.roll(1, -4)
-        output = self.clone()
-        output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
-        return Image.new_like(self, output)
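As a quick illustration of the reworked Image.pad (a sketch assuming the prototype API at this revision; values are arbitrary, and the non-scalar case relies on the colour-overwrite branch referenced in the comment above):

import torch
from torchvision.prototype import features

image = features.Image(torch.zeros(3, 32, 32, dtype=torch.uint8))

# Scalar fill goes straight to the tensor kernel.
padded = image.pad(4, fill=255)
assert padded.shape == (3, 40, 40)

# A per-channel sequence takes the overwrite path instead.
padded_rgb = image.pad([4, 2, 4, 2], fill=[255, 0, 0])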
 from __future__ import annotations

-from typing import Any, Optional, Sequence, cast, Union, Tuple
+from typing import Any, Optional, Sequence, cast, Union

 import torch
 from torchvision.prototype.utils._internal import apply_recursively
@@ -77,14 +77,3 @@ class OneHotLabel(_Feature):
         return super().new_like(
             other, data, categories=categories if categories is not None else other.categories, **kwargs
         )
-
-    def mixup(self, lam: float) -> OneHotLabel:
-        if self.ndim < 2:
-            raise ValueError("Need a batch of one hot labels")
-        output = self.clone()
-        output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
-        return OneHotLabel.new_like(self, output)
-
-    def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> OneHotLabel:
-        box  # unused
-        return self.mixup(lam_adjusted)
 from __future__ import annotations

-from typing import Tuple, List, Optional, Union, Sequence
+from typing import List, Optional, Union, Sequence

-import torch
 from torchvision.transforms import InterpolationMode

 from ._feature import _Feature
@@ -61,10 +60,17 @@ class SegmentationMask(_Feature):
         return SegmentationMask.new_like(self, output)

     def pad(
-        self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
+        self,
+        padding: Union[int, Sequence[int]],
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+        padding_mode: str = "constant",
     ) -> SegmentationMask:
         from torchvision.prototype.transforms import functional as _F

+        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
+        if not isinstance(padding, int):
+            padding = list(padding)
+
         output = _F.pad_segmentation_mask(self, padding, padding_mode=padding_mode)
         return SegmentationMask.new_like(self, output)
@@ -73,7 +79,7 @@ class SegmentationMask(_Feature):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> SegmentationMask:
         from torchvision.prototype.transforms import functional as _F
@@ -88,7 +94,7 @@ class SegmentationMask(_Feature):
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
     ) -> SegmentationMask:
         from torchvision.prototype.transforms import functional as _F
@@ -107,18 +113,9 @@ class SegmentationMask(_Feature):
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
     ) -> SegmentationMask:
         from torchvision.prototype.transforms import functional as _F

         output = _F.perspective_segmentation_mask(self, perspective_coeffs)
         return SegmentationMask.new_like(self, output)
-
-    def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> SegmentationMask:
-        raise TypeError("Erase transformation does not support segmentation masks")
-
-    def mixup(self, lam: float) -> SegmentationMask:
-        raise TypeError("Mixup transformation does not support segmentation masks")
-
-    def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> SegmentationMask:
-        raise TypeError("Cutmix transformation does not support segmentation masks")
@@ -17,6 +17,8 @@ from ._geometry import (
     RandomVerticalFlip,
     Pad,
     RandomZoomOut,
+    RandomRotation,
+    RandomAffine,
 )
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
...
@@ -9,7 +9,7 @@ from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F

 from ._transform import _RandomApplyTransform
-from ._utils import query_image, get_image_dimensions, has_all
+from ._utils import query_image, get_image_dimensions, has_any, has_all


 class RandomErasing(_RandomApplyTransform):
@@ -86,13 +86,14 @@ class RandomErasing(_RandomApplyTransform):
         return dict(i=i, j=j, h=h, w=w, v=v)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features._Feature):
-            return inpt.erase(**params)
+        if isinstance(inpt, (features.Image, torch.Tensor)):
+            output = F.erase_image_tensor(inpt, **params)
+            if isinstance(inpt, features.Image):
+                return features.Image.new_like(inpt, output)
+            return output
         elif isinstance(inpt, PIL.Image.Image):
             # TODO: We should implement a fallback to tensor, like gaussian_blur etc
             raise RuntimeError("Not implemented")
-        elif isinstance(inpt, torch.Tensor):
-            return F.erase_image_tensor(inpt, **params)
         else:
             return inpt
@@ -107,16 +108,34 @@ class _BaseMixupCutmix(Transform):
         sample = inpts if len(inpts) > 1 else inpts[0]
         if not has_all(sample, features.Image, features.OneHotLabel):
             raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
+        if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
+            raise TypeError(
+                f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels."
+            )
         return super().forward(sample)

+    def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
+        if inpt.ndim < 2:
+            raise ValueError("Need a batch of one hot labels")
+        output = inpt.clone()
+        output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
+        return features.OneHotLabel.new_like(inpt, output)
+

 class RandomMixup(_BaseMixupCutmix):
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features._Feature):
-            return inpt.mixup(**params)
+        lam = params["lam"]
+        if isinstance(inpt, features.Image):
+            if inpt.ndim < 4:
+                raise ValueError("Need a batch of images")
+            output = inpt.clone()
+            output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
+            return features.Image.new_like(inpt, output)
+        elif isinstance(inpt, features.OneHotLabel):
+            return self._mixup_onehotlabel(inpt, lam)
         else:
             return inpt
@@ -146,7 +165,17 @@ class RandomCutmix(_BaseMixupCutmix):
         return dict(box=box, lam_adjusted=lam_adjusted)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features._Feature):
-            return inpt.cutmix(**params)
+        if isinstance(inpt, features.Image):
+            box = params["box"]
+            if inpt.ndim < 4:
+                raise ValueError("Need a batch of images")
+            x1, y1, x2, y2 = box
+            image_rolled = inpt.roll(1, -4)
+            output = inpt.clone()
+            output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
+            return features.Image.new_like(inpt, output)
+        elif isinstance(inpt, features.OneHotLabel):
+            lam_adjusted = params["lam_adjusted"]
+            return self._mixup_onehotlabel(inpt, lam_adjusted)
         else:
             return inpt
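A minimal sketch of the new in-transform mixup path (batch sizes and the Beta parameter are illustrative; assumes the prototype API at this revision):

import torch
from torchvision.prototype import features, transforms

images = features.Image(torch.rand(8, 3, 32, 32))    # batch of images
labels = features.OneHotLabel(torch.eye(10)[:8])      # batch of one-hot labels

mixup = transforms.RandomMixup(alpha=1.0)
mixed_images, mixed_labels = mixup(images, labels)

# Passing a BoundingBox, SegmentationMask, or plain Label now raises a
# TypeError in _BaseMixupCutmix.forward() instead of failing later.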
 import math
-from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union, Type
+from typing import Any, Dict, Tuple, Optional, Callable, List, cast, Sequence, TypeVar, Union, Type

 import PIL.Image
 import torch
@@ -9,7 +9,7 @@ from torchvision.prototype.utils._internal import query_recursively
 from torchvision.transforms.autoaugment import AutoAugmentPolicy
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image, InterpolationMode

-from ._utils import get_image_dimensions, is_simple_tensor
+from ._utils import get_image_dimensions

 K = TypeVar("K")
 V = TypeVar("V")
@@ -29,7 +29,10 @@ def _put_into_sample(sample: Any, id: Tuple[Any, ...], item: Any) -> Any:
 class _AutoAugmentBase(Transform):
     def __init__(
-        self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None
+        self,
+        *,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
     ) -> None:
         super().__init__()
         self.interpolation = interpolation
@@ -66,7 +69,7 @@ class _AutoAugmentBase(Transform):
     def _parse_fill(
         self, image: Union[PIL.Image.Image, torch.Tensor, features.Image], num_channels: int
-    ) -> Optional[List[float]]:
+    ) -> Union[int, float, Sequence[int], Sequence[float]]:
         fill = self.fill

         if isinstance(image, PIL.Image.Image) or fill is None:
@@ -79,36 +82,18 @@ class _AutoAugmentBase(Transform):
         return fill

-    def _dispatch_image_kernels(
-        self,
-        image_tensor_kernel: Callable,
-        image_pil_kernel: Callable,
-        input: Any,
-        *args: Any,
-        **kwargs: Any,
-    ) -> Any:
-        if isinstance(input, features.Image):
-            output = image_tensor_kernel(input, *args, **kwargs)
-            return features.Image.new_like(input, output)
-        elif is_simple_tensor(input):
-            return image_tensor_kernel(input, *args, **kwargs)
-        else:  # isinstance(input, PIL.Image.Image):
-            return image_pil_kernel(input, *args, **kwargs)
-
     def _apply_image_transform(
         self,
         image: Any,
         transform_id: str,
         magnitude: float,
         interpolation: InterpolationMode,
-        fill: Optional[List[float]],
+        fill: Union[int, float, Sequence[int], Sequence[float]],
     ) -> Any:
         if transform_id == "Identity":
             return image
         elif transform_id == "ShearX":
-            return self._dispatch_image_kernels(
-                F.affine_image_tensor,
-                F.affine_image_pil,
+            return F.affine(
                 image,
                 angle=0.0,
                 translate=[0, 0],
@@ -118,9 +103,7 @@ class _AutoAugmentBase(Transform):
                 fill=fill,
             )
         elif transform_id == "ShearY":
-            return self._dispatch_image_kernels(
-                F.affine_image_tensor,
-                F.affine_image_pil,
+            return F.affine(
                 image,
                 angle=0.0,
                 translate=[0, 0],
@@ -130,9 +113,7 @@ class _AutoAugmentBase(Transform):
                 fill=fill,
             )
         elif transform_id == "TranslateX":
-            return self._dispatch_image_kernels(
-                F.affine_image_tensor,
-                F.affine_image_pil,
+            return F.affine(
                 image,
                 angle=0.0,
                 translate=[int(magnitude), 0],
@@ -142,9 +123,7 @@ class _AutoAugmentBase(Transform):
                 fill=fill,
             )
         elif transform_id == "TranslateY":
-            return self._dispatch_image_kernels(
-                F.affine_image_tensor,
-                F.affine_image_pil,
+            return F.affine(
                 image,
                 angle=0.0,
                 translate=[0, int(magnitude)],
@@ -154,46 +133,25 @@ class _AutoAugmentBase(Transform):
                 fill=fill,
             )
         elif transform_id == "Rotate":
-            return self._dispatch_image_kernels(F.rotate_image_tensor, F.rotate_image_pil, image, angle=magnitude)
+            return F.rotate(image, angle=magnitude)
         elif transform_id == "Brightness":
-            return self._dispatch_image_kernels(
-                F.adjust_brightness_image_tensor,
-                F.adjust_brightness_image_pil,
-                image,
-                brightness_factor=1.0 + magnitude,
-            )
+            return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
         elif transform_id == "Color":
-            return self._dispatch_image_kernels(
-                F.adjust_saturation_image_tensor,
-                F.adjust_saturation_image_pil,
-                image,
-                saturation_factor=1.0 + magnitude,
-            )
+            return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
         elif transform_id == "Contrast":
-            return self._dispatch_image_kernels(
-                F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, image, contrast_factor=1.0 + magnitude
-            )
+            return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
         elif transform_id == "Sharpness":
-            return self._dispatch_image_kernels(
-                F.adjust_sharpness_image_tensor,
-                F.adjust_sharpness_image_pil,
-                image,
-                sharpness_factor=1.0 + magnitude,
-            )
+            return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
         elif transform_id == "Posterize":
-            return self._dispatch_image_kernels(
-                F.posterize_image_tensor, F.posterize_image_pil, image, bits=int(magnitude)
-            )
+            return F.posterize(image, bits=int(magnitude))
         elif transform_id == "Solarize":
-            return self._dispatch_image_kernels(
-                F.solarize_image_tensor, F.solarize_image_pil, image, threshold=magnitude
-            )
+            return F.solarize(image, threshold=magnitude)
         elif transform_id == "AutoContrast":
-            return self._dispatch_image_kernels(F.autocontrast_image_tensor, F.autocontrast_image_pil, image)
+            return F.autocontrast(image)
         elif transform_id == "Equalize":
-            return self._dispatch_image_kernels(F.equalize_image_tensor, F.equalize_image_pil, image)
+            return F.equalize(image)
         elif transform_id == "Invert":
-            return self._dispatch_image_kernels(F.invert_image_tensor, F.invert_image_pil, image)
+            return F.invert(image)
         else:
             raise ValueError(f"No transform available for {transform_id}")
@@ -231,7 +189,7 @@ class AutoAugment(_AutoAugmentBase):
         self,
         policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
@@ -393,7 +351,7 @@ class RandAugment(_AutoAugmentBase):
         magnitude: int = 9,
         num_magnitude_bins: int = 31,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
@@ -453,7 +411,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
         *,
         num_magnitude_bins: int = 31,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None,
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
     ):
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins
@@ -512,7 +470,7 @@ class AugMix(_AutoAugmentBase):
         alpha: float = 1.0,
         all_ops: bool = True,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[List[float]] = None,
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
...
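The removed _dispatch_image_kernels helper is no longer needed because type-wise dispatch now lives in the functional namespace itself: the same call works for feature images, plain tensors, and PIL images. A sketch (assuming the prototype API at this revision):

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

tensor_image = torch.rand(3, 32, 32)
feature_image = features.Image(tensor_image)
pil_image = PIL.Image.new("RGB", (32, 32))

# One entry point per op; each dispatcher routes to the matching kernel.
outputs = [F.rotate(inpt, angle=30.0) for inpt in (tensor_image, feature_image, pil_image)]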
 import collections.abc
-import functools
-from typing import Any, Dict, Union, Tuple, Optional, Sequence, Callable, TypeVar
+from typing import Any, Dict, Union, Tuple, Optional, Sequence, TypeVar

 import PIL.Image
 import torch
@@ -53,76 +52,36 @@ class ColorJitter(Transform):
         return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))

-    def _image_transform(
-        self,
-        input: T,
-        *,
-        kernel_tensor: Callable[..., torch.Tensor],
-        kernel_pil: Callable[..., PIL.Image.Image],
-        **kwargs: Any,
-    ) -> T:
-        if isinstance(input, features.Image):
-            output = kernel_tensor(input, **kwargs)
-            return features.Image.new_like(input, output)
-        elif is_simple_tensor(input):
-            return kernel_tensor(input, **kwargs)
-        elif isinstance(input, PIL.Image.Image):
-            return kernel_pil(input, **kwargs)  # type: ignore[no-any-return]
-        else:
-            raise RuntimeError
+    @staticmethod
+    def _generate_value(left: float, right: float) -> float:
+        return float(torch.distributions.Uniform(left, right).sample())

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        image_transforms = []
-        if self.brightness is not None:
-            image_transforms.append(
-                functools.partial(
-                    self._image_transform,
-                    kernel_tensor=F.adjust_brightness_image_tensor,
-                    kernel_pil=F.adjust_brightness_image_pil,
-                    brightness_factor=float(
-                        torch.distributions.Uniform(self.brightness[0], self.brightness[1]).sample()
-                    ),
-                )
-            )
-        if self.contrast is not None:
-            image_transforms.append(
-                functools.partial(
-                    self._image_transform,
-                    kernel_tensor=F.adjust_contrast_image_tensor,
-                    kernel_pil=F.adjust_contrast_image_pil,
-                    contrast_factor=float(torch.distributions.Uniform(self.contrast[0], self.contrast[1]).sample()),
-                )
-            )
-        if self.saturation is not None:
-            image_transforms.append(
-                functools.partial(
-                    self._image_transform,
-                    kernel_tensor=F.adjust_saturation_image_tensor,
-                    kernel_pil=F.adjust_saturation_image_pil,
-                    saturation_factor=float(
-                        torch.distributions.Uniform(self.saturation[0], self.saturation[1]).sample()
-                    ),
-                )
-            )
-        if self.hue is not None:
-            image_transforms.append(
-                functools.partial(
-                    self._image_transform,
-                    kernel_tensor=F.adjust_hue_image_tensor,
-                    kernel_pil=F.adjust_hue_image_pil,
-                    hue_factor=float(torch.distributions.Uniform(self.hue[0], self.hue[1]).sample()),
-                )
-            )
-
-        return dict(image_transforms=[image_transforms[idx] for idx in torch.randperm(len(image_transforms))])
-
-    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)):
-            return input
-
-        for transform in params["image_transforms"]:
-            input = transform(input)
-
-        return input
+        fn_idx = torch.randperm(4)
+
+        b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
+        c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1])
+        s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1])
+        h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1])
+
+        return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        output = inpt
+        brightness_factor = params["brightness_factor"]
+        contrast_factor = params["contrast_factor"]
+        saturation_factor = params["saturation_factor"]
+        hue_factor = params["hue_factor"]
+        for fn_id in params["fn_idx"]:
+            if fn_id == 0 and brightness_factor is not None:
+                output = F.adjust_brightness(output, brightness_factor=brightness_factor)
+            elif fn_id == 1 and contrast_factor is not None:
+                output = F.adjust_contrast(output, contrast_factor=contrast_factor)
+            elif fn_id == 2 and saturation_factor is not None:
+                output = F.adjust_saturation(output, saturation_factor=saturation_factor)
+            elif fn_id == 3 and hue_factor is not None:
+                output = F.adjust_hue(output, hue_factor=hue_factor)
+        return output


 class _RandomChannelShuffle(Transform):
@@ -131,19 +90,19 @@ class _RandomChannelShuffle(Transform):
         num_channels, _, _ = get_image_dimensions(image)
         return dict(permutation=torch.randperm(num_channels))

-    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)):
-            return input
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
+            return inpt

-        image = input
-        if isinstance(input, PIL.Image.Image):
+        image = inpt
+        if isinstance(inpt, PIL.Image.Image):
             image = _F.pil_to_tensor(image)

         output = image[..., params["permutation"], :, :]

-        if isinstance(input, features.Image):
-            output = features.Image.new_like(input, output, color_space=features.ColorSpace.OTHER)
-        elif isinstance(input, PIL.Image.Image):
+        if isinstance(inpt, features.Image):
+            output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
+        elif isinstance(inpt, PIL.Image.Image):
             output = _F.to_pil_image(output)

         return output
@@ -175,33 +134,25 @@ class RandomPhotometricDistort(Transform):
             contrast_before=torch.rand(()) < 0.5,
         )

-    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["brightness"]:
-            input = self._brightness(input)
+            inpt = self._brightness(inpt)
         if params["contrast1"] and params["contrast_before"]:
-            input = self._contrast(input)
+            inpt = self._contrast(inpt)
         if params["saturation"]:
-            input = self._saturation(input)
+            inpt = self._saturation(inpt)
         if params["saturation"]:
-            input = self._saturation(input)
+            inpt = self._saturation(inpt)
         if params["contrast2"] and not params["contrast_before"]:
-            input = self._contrast(input)
+            inpt = self._contrast(inpt)
         if params["channel_shuffle"]:
-            input = self._channel_shuffle(input)
-        return input
+            inpt = self._channel_shuffle(inpt)
+        return inpt


 class RandomEqualize(_RandomApplyTransform):
     def __init__(self, p: float = 0.5):
         super().__init__(p=p)

-    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, features.Image):
-            output = F.equalize_image_tensor(input)
-            return features.Image.new_like(input, output)
-        elif is_simple_tensor(input):
-            return F.equalize_image_tensor(input)
-        elif isinstance(input, PIL.Image.Image):
-            return F.equalize_image_pil(input)
-        else:
-            return input
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.equalize(inpt)
@@ -2,14 +2,14 @@ import collections.abc
 import math
 import numbers
 import warnings
-from typing import Any, Dict, List, Union, Sequence, Tuple, cast
+from typing import Any, Dict, List, Optional, Union, Sequence, Tuple, cast

 import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
 from torchvision.transforms.functional import pil_to_tensor, InterpolationMode
-from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
+from torchvision.transforms.transforms import _setup_size, _setup_angle, _check_sequence_input
 from typing_extensions import Literal

 from ._transform import _RandomApplyTransform
@@ -17,41 +17,13 @@ from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor


 class RandomHorizontalFlip(_RandomApplyTransform):
-    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, features.Image):
-            output = F.horizontal_flip_image_tensor(input)
-            return features.Image.new_like(input, output)
-        elif isinstance(input, features.SegmentationMask):
-            output = F.horizontal_flip_segmentation_mask(input)
-            return features.SegmentationMask.new_like(input, output)
-        elif isinstance(input, features.BoundingBox):
-            output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
-            return features.BoundingBox.new_like(input, output)
-        elif isinstance(input, PIL.Image.Image):
-            return F.horizontal_flip_image_pil(input)
-        elif is_simple_tensor(input):
-            return F.horizontal_flip_image_tensor(input)
-        else:
-            return input
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.horizontal_flip(inpt)


 class RandomVerticalFlip(_RandomApplyTransform):
-    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, features.Image):
-            output = F.vertical_flip_image_tensor(input)
-            return features.Image.new_like(input, output)
-        elif isinstance(input, features.SegmentationMask):
-            output = F.vertical_flip_segmentation_mask(input)
-            return features.SegmentationMask.new_like(input, output)
-        elif isinstance(input, features.BoundingBox):
-            output = F.vertical_flip_bounding_box(input, format=input.format, image_size=input.image_size)
-            return features.BoundingBox.new_like(input, output)
-        elif isinstance(input, PIL.Image.Image):
-            return F.vertical_flip_image_pil(input)
-        elif is_simple_tensor(input):
-            return F.vertical_flip_image_tensor(input)
-        else:
-            return input
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.vertical_flip(inpt)


 class Resize(Transform):
@@ -59,27 +31,23 @@ class Resize(Transform):
         self,
         size: Union[int, Sequence[int]],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        max_size: Optional[int] = None,
+        antialias: Optional[bool] = None,
     ) -> None:
         super().__init__()
         self.size = [size] if isinstance(size, int) else list(size)
         self.interpolation = interpolation
+        self.max_size = max_size
+        self.antialias = antialias

-    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, features.Image):
-            output = F.resize_image_tensor(input, self.size, interpolation=self.interpolation)
-            return features.Image.new_like(input, output)
-        elif isinstance(input, features.SegmentationMask):
-            output = F.resize_segmentation_mask(input, self.size)
-            return features.SegmentationMask.new_like(input, output)
-        elif isinstance(input, features.BoundingBox):
-            output = F.resize_bounding_box(input, self.size, image_size=input.image_size)
-            return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size)))
-        elif isinstance(input, PIL.Image.Image):
-            return F.resize_image_pil(input, self.size, interpolation=self.interpolation)
-        elif is_simple_tensor(input):
-            return F.resize_image_tensor(input, self.size, interpolation=self.interpolation)
-        else:
-            return input
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.resize(
+            inpt,
+            self.size,
+            interpolation=self.interpolation,
+            max_size=self.max_size,
+            antialias=self.antialias,
+        )


 class CenterCrop(Transform):
@@ -87,22 +55,8 @@ class CenterCrop(Transform):
         super().__init__()
         self.output_size = output_size

-    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, features.Image):
-            output = F.center_crop_image_tensor(input, self.output_size)
-            return features.Image.new_like(input, output)
-        elif is_simple_tensor(input):
-            return F.center_crop_image_tensor(input, self.output_size)
-        elif isinstance(input, PIL.Image.Image):
-            return F.center_crop_image_pil(input, self.output_size)
-        else:
-            return input
-
-    def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        if has_any(sample, features.BoundingBox, features.SegmentationMask):
-            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
-        return super().forward(sample)
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.center_crop(inpt, output_size=self.output_size)


 class RandomResizedCrop(Transform):
@@ -112,6 +66,7 @@ class RandomResizedCrop(Transform):
         scale: Tuple[float, float] = (0.08, 1.0),
         ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        antialias: Optional[bool] = None,
     ) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -125,20 +80,16 @@ class RandomResizedCrop(Transform):
         if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
             warnings.warn("Scale and ratio should be of kind (min, max)")

-        # Backward compatibility with integer value
-        if isinstance(interpolation, int):
-            warnings.warn(
-                "Argument interpolation should be of type InterpolationMode instead of int. "
-                "Please, use InterpolationMode enum."
-            )
-            interpolation = _interpolation_modes_from_int(interpolation)
-
         self.size = size
         self.scale = scale
         self.ratio = ratio
         self.interpolation = interpolation
+        self.antialias = antialias

     def _get_params(self, sample: Any) -> Dict[str, Any]:
+        # vfdev-5: techically, this op can work on bboxes/segm masks only inputs without image in samples
+        # What if we have multiple images/bboxes/masks of different sizes ?
+        # TODO: let's support bbox or mask in samples without image
         image = query_image(sample)
         _, height, width = get_image_dimensions(image)
         area = height * width
@@ -177,24 +128,10 @@ class RandomResizedCrop(Transform):
         return dict(top=i, left=j, height=h, width=w)

-    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, features.Image):
-            output = F.resized_crop_image_tensor(
-                input, **params, size=list(self.size), interpolation=self.interpolation
-            )
-            return features.Image.new_like(input, output)
-        elif is_simple_tensor(input):
-            return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation)
-        elif isinstance(input, PIL.Image.Image):
-            return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation)
-        else:
-            return input
-
-    def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        if has_any(sample, features.BoundingBox, features.SegmentationMask):
-            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
-        return super().forward(sample)
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.resized_crop(
+            inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
+        )


 class MultiCropResult(list):
@@ -283,19 +220,23 @@ class BatchMultiCrop(Transform):
         return apply_recursively(inputs if len(inputs) > 1 else inputs[0])

+def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None:
+    if not isinstance(fill, (numbers.Number, tuple, list)):
+        raise TypeError("Got inappropriate fill arg")
+
+
 class Pad(Transform):
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[float, Sequence[float]] = 0.0,
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
         if not isinstance(padding, (numbers.Number, tuple, list)):
             raise TypeError("Got inappropriate padding arg")

-        if not isinstance(fill, (numbers.Number, str, tuple, list)):
-            raise TypeError("Got inappropriate fill arg")
+        _check_fill_arg(fill)

         if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
             raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
@@ -309,55 +250,20 @@ class Pad(Transform):
         self.fill = fill
         self.padding_mode = padding_mode

-    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, features.Image) or is_simple_tensor(input):
-            # PyTorch's pad supports only integers on fill. So we need to overwrite the colour
-            output = F.pad_image_tensor(input, params["padding"], fill=0, padding_mode="constant")
-
-            left, top, right, bottom = params["padding"]
-            fill = torch.tensor(params["fill"], dtype=input.dtype, device=input.device).to().view(-1, 1, 1)
-
-            if top > 0:
-                output[..., :top, :] = fill
-            if left > 0:
-                output[..., :, :left] = fill
-            if bottom > 0:
-                output[..., -bottom:, :] = fill
-            if right > 0:
-                output[..., :, -right:] = fill
-
-            if isinstance(input, features.Image):
-                output = features.Image.new_like(input, output)
-
-            return output
-        elif isinstance(input, PIL.Image.Image):
-            return F.pad_image_pil(
-                input,
-                params["padding"],
-                fill=tuple(int(v) if input.mode != "F" else v for v in params["fill"]),
-                padding_mode="constant",
-            )
-        elif isinstance(input, features.BoundingBox):
-            output = F.pad_bounding_box(input, params["padding"], format=input.format)
-
-            left, top, right, bottom = params["padding"]
-            height, width = input.image_size
-            height += top + bottom
-            width += left + right
-
-            return features.BoundingBox.new_like(input, output, image_size=(height, width))
-        else:
-            return input
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)
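Since Pad now defers to the polymorphic F.pad, one transform instance handles images and bounding boxes alike; a minimal sketch (assuming the prototype API at this revision; data is illustrative):

import torch
from torchvision.prototype import features, transforms

pad = transforms.Pad(2, fill=127)
image = features.Image(torch.zeros(3, 10, 10, dtype=torch.uint8))
bbox = features.BoundingBox(
    torch.tensor([[1.0, 1.0, 4.0, 4.0]]), format="XYXY", image_size=(10, 10)
)
# Boxes shift by the left/top padding and the recorded image_size is updated;
# fill is ignored for the box branch.
padded_image, padded_bbox = pad(image, bbox)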
 class RandomZoomOut(_RandomApplyTransform):
     def __init__(
-        self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
+        self,
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        side_range: Tuple[float, float] = (1.0, 4.0),
+        p: float = 0.5,
     ) -> None:
         super().__init__(p=p)

-        if fill is None:
-            fill = 0.0
+        _check_fill_arg(fill)
         self.fill = fill

         self.side_range = side_range
@@ -385,6 +291,126 @@ class RandomZoomOut(_RandomApplyTransform):
         return dict(padding=padding, fill=fill)

-    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        transform = Pad(**params, padding_mode="constant")
-        return transform(input)
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.pad(inpt, **params)
class RandomRotation(Transform):
def __init__(
self,
degrees: Union[numbers.Number, Sequence],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
self.interpolation = interpolation
self.expand = expand
_check_fill_arg(fill)
self.fill = fill
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
def _get_params(self, sample: Any) -> Dict[str, Any]:
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.rotate(
inpt,
**params,
interpolation=self.interpolation,
expand=self.expand,
fill=self.fill,
center=self.center,
)
class RandomAffine(Transform):
    def __init__(
        self,
        degrees: Union[numbers.Number, Sequence],
        translate: Optional[Sequence[float]] = None,
        scale: Optional[Sequence[float]] = None,
        shear: Optional[Union[float, Sequence[float]]] = None,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
        center: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2,))
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2,))
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
        else:
            self.shear = shear

        self.interpolation = interpolation

        _check_fill_arg(fill)
        self.fill = fill

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))
        self.center = center

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        # Get image size
        # TODO: make it work with bboxes and segm masks
        image = query_image(sample)
        _, height, width = get_image_dimensions(image)

        angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
        if self.translate is not None:
            max_dx = float(self.translate[0] * width)
            max_dy = float(self.translate[1] * height)
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translations = (tx, ty)
        else:
            translations = (0, 0)

        if self.scale is not None:
            scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item())
        else:
            scale = 1.0

        shear_x = shear_y = 0.0
        if self.shear is not None:
            shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item())
            if len(self.shear) == 4:
                shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item())
        shear = (shear_x, shear_y)

        # NOTE: the key is "translate" (not "translations") so that **params below
        # matches F.affine's keyword arguments.
        return dict(angle=angle, translate=translations, scale=scale, shear=shear)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.affine(
            inpt,
            **params,
            interpolation=self.interpolation,
            fill=self.fill,
            center=self.center,
        )
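A hedged usage sketch (the ranges are arbitrary; plain tensors are wrapped as features.Image here):

import torch
from torchvision.prototype import features, transforms

affine = transforms.RandomAffine(degrees=(-45, 45), translate=(0.1, 0.1), scale=(0.8, 1.2))
img = features.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))
out = affine(img)
# Each call draws angle ~ U(-45, 45), integer pixel translations bounded by
# 0.1 * width and 0.1 * height, and scale ~ U(0.8, 1.2).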
...
@@ -16,9 +16,10 @@ adjust_brightness_image_pil = _FP.adjust_brightness
def adjust_brightness(inpt: DType, brightness_factor: float) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.adjust_brightness(brightness_factor=brightness_factor)
    elif isinstance(inpt, PIL.Image.Image):
        return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
    else:
        return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
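Every wrapper in this file now follows the same explicit if/elif/else dispatch: _Feature inputs call their own method, PIL images route to the PIL kernel, and everything else is treated as an image tensor. For example (assuming the prototype functional module is importable as below):

import PIL.Image
import torch
from torchvision.prototype.transforms import functional as F

F.adjust_brightness(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8), brightness_factor=1.5)  # tensor kernel
F.adjust_brightness(PIL.Image.new("RGB", (8, 8)), brightness_factor=1.5)  # PIL kernel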
adjust_saturation_image_tensor = _FT.adjust_saturation
@@ -28,9 +29,10 @@ adjust_saturation_image_pil = _FP.adjust_saturation
def adjust_saturation(inpt: DType, saturation_factor: float) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.adjust_saturation(saturation_factor=saturation_factor)
    elif isinstance(inpt, PIL.Image.Image):
        return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
    else:
        return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)


adjust_contrast_image_tensor = _FT.adjust_contrast
@@ -40,9 +42,10 @@ adjust_contrast_image_pil = _FP.adjust_contrast
def adjust_contrast(inpt: DType, contrast_factor: float) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.adjust_contrast(contrast_factor=contrast_factor)
    elif isinstance(inpt, PIL.Image.Image):
        return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
    else:
        return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)


adjust_sharpness_image_tensor = _FT.adjust_sharpness
@@ -52,9 +55,10 @@ adjust_sharpness_image_pil = _FP.adjust_sharpness
def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
    elif isinstance(inpt, PIL.Image.Image):
        return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
    else:
        return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)


adjust_hue_image_tensor = _FT.adjust_hue
@@ -64,9 +68,10 @@ adjust_hue_image_pil = _FP.adjust_hue
def adjust_hue(inpt: DType, hue_factor: float) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.adjust_hue(hue_factor=hue_factor)
    elif isinstance(inpt, PIL.Image.Image):
        return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
    else:
        return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)


adjust_gamma_image_tensor = _FT.adjust_gamma
@@ -76,9 +81,10 @@ adjust_gamma_image_pil = _FP.adjust_gamma
def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.adjust_gamma(gamma=gamma, gain=gain)
    elif isinstance(inpt, PIL.Image.Image):
        return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
    else:
        return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)


posterize_image_tensor = _FT.posterize
@@ -88,9 +94,10 @@ posterize_image_pil = _FP.posterize
def posterize(inpt: DType, bits: int) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.posterize(bits=bits)
    elif isinstance(inpt, PIL.Image.Image):
        return posterize_image_pil(inpt, bits=bits)
    else:
        return posterize_image_tensor(inpt, bits=bits)


solarize_image_tensor = _FT.solarize
@@ -100,9 +107,10 @@ solarize_image_pil = _FP.solarize
def solarize(inpt: DType, threshold: float) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.solarize(threshold=threshold)
    elif isinstance(inpt, PIL.Image.Image):
        return solarize_image_pil(inpt, threshold=threshold)
    else:
        return solarize_image_tensor(inpt, threshold=threshold)


autocontrast_image_tensor = _FT.autocontrast
@@ -112,9 +120,10 @@ autocontrast_image_pil = _FP.autocontrast
def autocontrast(inpt: DType) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.autocontrast()
    elif isinstance(inpt, PIL.Image.Image):
        return autocontrast_image_pil(inpt)
    else:
        return autocontrast_image_tensor(inpt)


equalize_image_tensor = _FT.equalize
@@ -124,9 +133,10 @@ equalize_image_pil = _FP.equalize
def equalize(inpt: DType) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.equalize()
    elif isinstance(inpt, PIL.Image.Image):
        return equalize_image_pil(inpt)
    else:
        return equalize_image_tensor(inpt)


invert_image_tensor = _FT.invert
@@ -136,6 +146,7 @@ invert_image_pil = _FP.invert
def invert(inpt: DType) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.invert()
    elif isinstance(inpt, PIL.Image.Image):
        return invert_image_pil(inpt)
    else:
        return invert_image_tensor(inpt)
...
@@ -47,9 +47,10 @@ def horizontal_flip_bounding_box(
def horizontal_flip(inpt: DType) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.horizontal_flip()
    elif isinstance(inpt, PIL.Image.Image):
        return horizontal_flip_image_pil(inpt)
    else:
        return horizontal_flip_image_tensor(inpt)


vertical_flip_image_tensor = _FT.vflip
@@ -79,9 +80,10 @@ def vertical_flip_bounding_box(
def vertical_flip(inpt: DType) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.vertical_flip()
    elif isinstance(inpt, PIL.Image.Image):
        return vertical_flip_image_pil(inpt)
    else:
        return vertical_flip_image_tensor(inpt)
def resize_image_tensor(
    ...
@@ -141,13 +143,13 @@ def resize(
    if isinstance(inpt, features._Feature):
        antialias = False if antialias is None else antialias
        return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
    elif isinstance(inpt, PIL.Image.Image):
        if antialias is not None and not antialias:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
    else:
        antialias = False if antialias is None else antialias
        return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
def _affine_parse_args(
    ...
@@ -210,18 +212,22 @@ def affine_image_tensor(
    fill: Optional[List[float]] = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    num_channels, height, width = img.shape[-3:]
    extra_dims = img.shape[:-3]
    img = img.view(-1, num_channels, height, width)

    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [1.0 * t for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

    output = _FT.affine(img, matrix, interpolation=interpolation.value, fill=fill)
    return output.view(extra_dims + (num_channels, height, width))
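The view/restore pair added above is what lets the kernel accept arbitrary leading batch dimensions: everything before CHW is flattened into a single batch axis for the kernel and restored afterwards. The shape contract in isolation:

import torch

img = torch.rand(2, 5, 3, 32, 32)                 # extra dims (2, 5) + CHW
num_channels, height, width = img.shape[-3:]
extra_dims = img.shape[:-3]
flat = img.view(-1, num_channels, height, width)  # (10, 3, 32, 32), what the kernel sees
restored = flat.view(extra_dims + (num_channels, height, width))
assert restored.shape == img.shape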
def affine_image_pil(
@@ -231,7 +237,7 @@ def affine_image_pil(
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
@@ -344,7 +350,7 @@ def affine_bounding_box(
def affine_segmentation_mask(
    mask: torch.Tensor,
    angle: float,
    translate: List[float],
    scale: float,
@@ -352,7 +358,7 @@ def affine_segmentation_mask(
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    return affine_image_tensor(
        mask,
        angle=angle,
        translate=translate,
        scale=scale,
@@ -362,6 +368,19 @@ def affine_segmentation_mask(
    )
def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> Optional[List[float]]:
    if fill is None:
        fill = 0

    # This cast does Sequence -> List[float] to please mypy and torch.jit.script
    if not isinstance(fill, (int, float)):
        fill = [float(v) for v in list(fill)]
    else:
        # It is OK to cast int to float as later we use inpt.dtype
        fill = [float(fill)]

    return fill
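In effect the helper maps every accepted fill form onto List[float]; a standalone mirror of that behavior:

def convert_fill(fill):
    # Re-implementation of _convert_fill_arg for illustration only.
    if fill is None:
        fill = 0
    if not isinstance(fill, (int, float)):
        return [float(v) for v in fill]
    return [float(fill)]

assert convert_fill(None) == [0.0]
assert convert_fill(3) == [3.0]
assert convert_fill((12, 34, 56)) == [12.0, 34.0, 56.0]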
def affine(
    inpt: DType,
    angle: float,
@@ -369,14 +388,14 @@ def affine(
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
    center: Optional[List[float]] = None,
) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.affine(
            angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
        )
    elif isinstance(inpt, PIL.Image.Image):
        return affine_image_pil(
            inpt,
            angle,
@@ -387,16 +406,19 @@ def affine(
            fill=fill,
            center=center,
        )
    else:
        fill = _convert_fill_arg(fill)

        return affine_image_tensor(
            inpt,
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
def rotate_image_tensor(
@@ -407,6 +429,10 @@ def rotate_image_tensor(
    fill: Optional[List[float]] = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    num_channels, height, width = img.shape[-3:]
    extra_dims = img.shape[:-3]
    img = img.view(-1, num_channels, height, width)

    center_f = [0.0, 0.0]
    if center is not None:
        if expand:
@@ -419,7 +445,9 @@ def rotate_image_tensor(
    # due to current incoherence of rotation angle direction between affine and rotate implementations
    # we need to set -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
    output = _FT.rotate(img, matrix, interpolation=interpolation.value, expand=expand, fill=fill)
    new_height, new_width = output.shape[-2:]
    return output.view(extra_dims + (num_channels, new_height, new_width))
def rotate_image_pil(
@@ -427,7 +455,7 @@ def rotate_image_pil(
    angle: float,
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    if center is not None and expand:
@@ -483,37 +511,40 @@ def rotate(
    angle: float,
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
    center: Optional[List[float]] = None,
) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
    elif isinstance(inpt, PIL.Image.Image):
        return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
    else:
        fill = _convert_fill_arg(fill)

        return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
pad_image_pil = _FP.pad


def pad_image_tensor(
    img: torch.Tensor, padding: Union[int, List[int]], fill: Union[int, float] = 0, padding_mode: str = "constant"
) -> torch.Tensor:
    num_channels, height, width = img.shape[-3:]
    extra_dims = img.shape[:-3]

    padded_image = _FT.pad(
        img=img.view(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
    )

    new_height, new_width = padded_image.shape[-2:]
    return padded_image.view(extra_dims + (num_channels, new_height, new_width))
# TODO: This should be removed once pytorch pad supports non-scalar padding values
def _pad_with_vector_fill(
    img: torch.Tensor,
    padding: Union[int, List[int]],
    fill: Sequence[float] = [0.0],
    padding_mode: str = "constant",
) -> torch.Tensor:
@@ -521,7 +552,7 @@ def _pad_with_vector_fill(
        raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

    output = pad_image_tensor(img, padding, fill=0, padding_mode="constant")
    # _parse_pad_padding returns [left, right, top, bottom]; unpack in that order.
    left, right, top, bottom = _FT._parse_pad_padding(padding)
    fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1)

    if top > 0:
@@ -536,7 +567,7 @@ def _pad_with_vector_fill(
def pad_segmentation_mask(
    segmentation_mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant"
) -> torch.Tensor:
    num_masks, height, width = segmentation_mask.shape[-3:]
    extra_dims = segmentation_mask.shape[:-3]
@@ -550,7 +581,7 @@ def pad_segmentation_mask(
def pad_bounding_box(
    bounding_box: torch.Tensor, padding: Union[int, List[int]], format: features.BoundingBoxFormat
) -> torch.Tensor:
    left, _, top, _ = _FT._parse_pad_padding(padding)
@@ -566,18 +597,27 @@ def pad_bounding_box(
def pad(
    inpt: DType,
    padding: Union[int, Sequence[int]],
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
    padding_mode: str = "constant",
) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
    elif isinstance(inpt, PIL.Image.Image):
        return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
    else:
        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
        if not isinstance(padding, int):
            padding = list(padding)

        if fill is None:
            fill = 0

        # TODO: PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
        if isinstance(fill, (int, float)):
            return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
        return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode)
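On the tensor branch, a scalar fill stays on the fast pad_image_tensor path while a sequence fill takes the _pad_with_vector_fill workaround; for example (prototype functional namespace assumed):

import torch
from torchvision.prototype.transforms import functional as F

img = torch.zeros(3, 4, 4)
F.pad(img, padding=2, fill=1)            # scalar fill -> pad_image_tensor
F.pad(img, padding=2, fill=[255, 0, 0])  # per-channel fill -> _pad_with_vector_fill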
crop_image_tensor = _FT.crop
@@ -610,9 +650,10 @@ def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int,
def crop(inpt: DType, top: int, left: int, height: int, width: int) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.crop(top, left, height, width)
    elif isinstance(inpt, PIL.Image.Image):
        return crop_image_pil(inpt, top, left, height, width)
    else:
        return crop_image_tensor(inpt, top, left, height, width)
def perspective_image_tensor(
    ...
@@ -628,7 +669,7 @@ def perspective_image_pil(
    img: PIL.Image.Image,
    perspective_coeffs: List[float],
    interpolation: InterpolationMode = InterpolationMode.BICUBIC,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> PIL.Image.Image:
    return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
@@ -726,21 +767,25 @@ def perspective(
    inpt: DType,
    perspective_coeffs: List[float],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill)
    elif isinstance(inpt, PIL.Image.Image):
        return perspective_image_pil(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
    else:
        fill = _convert_fill_arg(fill)

        return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    if isinstance(output_size, numbers.Number):
        return [int(output_size), int(output_size)]
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        return [output_size[0], output_size[0]]
    else:
        return list(output_size)
def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
    ...
@@ -810,9 +855,10 @@ def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size:
def center_crop(inpt: DType, output_size: List[int]) -> DType:
    if isinstance(inpt, features._Feature):
        return inpt.center_crop(output_size)
    elif isinstance(inpt, PIL.Image.Image):
        return center_crop_image_pil(inpt, output_size)
    else:
        return center_crop_image_tensor(inpt, output_size)
def resized_crop_image_tensor(
    ...
@@ -880,12 +926,13 @@ def resized_crop(
    if isinstance(inpt, features._Feature):
        antialias = False if antialias is None else antialias
        return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
    elif isinstance(inpt, PIL.Image.Image):
        return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
    else:
        antialias = False if antialias is None else antialias
        return resized_crop_image_tensor(
            inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
        )


def _parse_five_crop_size(size: List[int]) -> List[int]:
...
import numbers
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
@@ -286,7 +286,7 @@ def affine(
    img: Image.Image,
    matrix: List[float],
    interpolation: int = _pil_constants.NEAREST,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    if not _is_pil_image(img):
@@ -304,7 +304,7 @@ def rotate(
    interpolation: int = _pil_constants.NEAREST,
    expand: bool = False,
    center: Optional[Tuple[int, int]] = None,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    if not _is_pil_image(img):
@@ -319,7 +319,7 @@ def perspective(
    img: Image.Image,
    perspective_coeffs: List[float],
    interpolation: int = _pil_constants.BICUBIC,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    if not _is_pil_image(img):
...
@@ -350,7 +350,7 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
        raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")


def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This may be unreachable
@@ -370,7 +370,9 @@ def _parse_pad_padding(padding: List[int]) -> List[int]:
    return [pad_left, pad_right, pad_top, pad_bottom]


def pad(
    img: Tensor, padding: Union[int, List[int]], fill: Union[int, float] = 0, padding_mode: str = "constant"
) -> Tensor:
    _assert_image_tensor(img)

    if not isinstance(padding, (int, tuple, list)):
@@ -383,8 +385,13 @@ def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mo
    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list):
        # TODO: Jit is failing on loading this op when scripted and saved
        # https://github.com/pytorch/pytorch/issues/81100
        if len(padding) not in [1, 2, 4]:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")