Unverified commit bd19fb8e authored by vfdev, committed by GitHub

[proto] Added mid-level ops and feature-based ops (#6219)

* Added mid-level ops and feature-based ops

* Fixing deadlock in dataloader with circular imports

* Added non-scalar fill support workaround for pad

* Removed comments

* int/float support for fill in pad op

* Updated type hints and removed bypass option from mid-level methods

* Minor nit fixes
parent b3b74481
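For context, the new mid-level ops added below accept plain tensors, PIL images, or the new prototype features and dispatch accordingly. A minimal usage sketch, assuming the prototype API introduced in this commit:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

img_tensor = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
img_feature = features.Image(img_tensor)

out_tensor = F.horizontal_flip(img_tensor)    # plain tensor -> tensor kernel
out_feature = F.horizontal_flip(img_feature)  # feature -> Image.horizontal_flip()
assert isinstance(out_feature, features.Image)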
...@@ -41,4 +41,4 @@ jobs:
- name: Run prototype tests
shell: bash
run: pytest -vvv --durations=20 test/test_prototype_*.py
...@@ -365,6 +365,18 @@ def rotate_segmentation_mask():
)
@register_kernel_info_from_sample_inputs_fn
def crop_image_tensor():
for image, top, left, height, width in itertools.product(make_images(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]):
yield SampleInput(
image,
top=top,
left=left,
height=height,
width=width,
)
@register_kernel_info_from_sample_inputs_fn
def crop_bounding_box():
for bounding_box, top, left in itertools.product(make_bounding_boxes(), [-8, 0, 9], [-8, 0, 9]):
...@@ -414,6 +426,17 @@ def resized_crop_segmentation_mask():
yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)
@register_kernel_info_from_sample_inputs_fn
def pad_image_tensor():
for image, padding, fill, padding_mode in itertools.product(
make_images(),
[[1], [1, 1], [1, 1, 2, 2]], # padding
[12, 12.0], # fill
["constant", "symmetric", "edge", "reflect"], # padding mode,
):
yield SampleInput(image, padding=padding, fill=fill, padding_mode=padding_mode)
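The padding shapes used in the sample inputs above follow the usual torchvision convention (one value for all sides, two for left/right and top/bottom, four for each side individually); a quick self-contained check:

import torch
import torchvision.transforms.functional as TF

img = torch.zeros(3, 10, 10)
assert TF.pad(img, [1]).shape == (3, 12, 12)           # all four sides padded by 1
assert TF.pad(img, [1, 2]).shape == (3, 14, 12)        # left/right by 1, top/bottom by 2
assert TF.pad(img, [1, 1, 2, 2]).shape == (3, 13, 13)  # left, top, right, bottom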
@register_kernel_info_from_sample_inputs_fn
def pad_segmentation_mask():
for mask, padding, padding_mode in itertools.product(
...@@ -499,6 +522,39 @@ def test_scriptable(kernel):
jit.script(kernel)
# Test below is intended to test mid-level op vs low-level ops it calls
# For example, resize -> resize_image_tensor, resize_bounding_boxes etc
# TODO: Rewrite this test as sample args may include more or fewer params
# than needed by the functions
@pytest.mark.parametrize(
"func",
[
pytest.param(func, id=name)
for name, func in F.__dict__.items()
if not name.startswith("_")
and callable(func)
and all(
feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"}
)
and name not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate"}
# We skip 'crop' due to missing 'height' and 'width'
# We skip 'rotate' because the expand=True case is not yet implemented for bboxes
],
)
def test_functional_mid_level(func):
finfos = [finfo for finfo in FUNCTIONAL_INFOS if f"{func.__name__}_" in finfo.name]
for finfo in finfos:
for sample_input in finfo.sample_inputs():
expected = finfo(sample_input)
kwargs = dict(sample_input.kwargs)
for key in ["format", "image_size"]:
if key in kwargs:
del kwargs[key]
output = func(*sample_input.args, **kwargs)
torch.testing.assert_close(output, expected, msg=f"finfo={finfo}, output={output}, expected={expected}")
break
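The kwargs stripping above works because a feature already carries its own metadata, so the mid-level op does not take format/image_size arguments. A small sketch of the low-level vs. mid-level call, assuming this commit's prototype API:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

box = features.BoundingBox(
    torch.tensor([[10.0, 10.0, 20.0, 20.0]]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=(32, 32),
)
out_low = F.resize_bounding_box(box, [64, 64], image_size=box.image_size)  # low-level kernel needs image_size
out_mid = F.resize(box, [64, 64])                                          # mid-level op reads it from the feature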
@pytest.mark.parametrize(
("functional_info", "sample_input"),
[
...
from __future__ import annotations
from typing import Any, List, Tuple, Union, Optional, Sequence
import torch
from torchvision._utils import StrEnum
from torchvision.transforms import InterpolationMode
from ._feature import _Feature
...@@ -69,3 +70,142 @@ class BoundingBox(_Feature):
return BoundingBox.new_like(
self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
)
def horizontal_flip(self) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
output = _F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size)
return BoundingBox.new_like(self, output)
def vertical_flip(self) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
output = _F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size)
return BoundingBox.new_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: bool = False,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
output = _F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size)
image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
output = _F.crop_bounding_box(self, self.format, top, left)
return BoundingBox.new_like(self, output, image_size=(height, width))
def center_crop(self, output_size: List[int]) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
output = _F.center_crop_bounding_box(
self, format=self.format, output_size=output_size, image_size=self.image_size
)
image_size = (output_size[0], output_size[0]) if len(output_size) == 1 else (output_size[0], output_size[1])
return BoundingBox.new_like(self, output, image_size=image_size)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
output = _F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
def pad(
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
if padding_mode not in ["constant"]:
raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")
output = _F.pad_bounding_box(self, padding, format=self.format)
# Update output image size:
# TODO: remove the import below and make _parse_pad_padding available
from torchvision.transforms.functional_tensor import _parse_pad_padding
left, top, right, bottom = _parse_pad_padding(padding)
height, width = self.image_size
height += top + bottom
width += left + right
return BoundingBox.new_like(self, output, image_size=(height, width))
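The image_size update above is plain arithmetic on the parsed padding; for example, with the [left, top, right, bottom] convention:

height, width = 32, 48
left, top, right, bottom = 1, 2, 3, 4          # padding = [1, 2, 3, 4]
new_image_size = (height + top + bottom, width + left + right)
assert new_image_size == (38, 52)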
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
output = _F.rotate_bounding_box(
self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center
)
# TODO: update output image size if expand is True
if expand:
raise RuntimeError("Not yet implemented")
return BoundingBox.new_like(self, output, dtype=output.dtype)
def affine(
self,
angle: float,
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
output = _F.affine_bounding_box(
self,
self.format,
self.image_size,
angle,
translate=translate,
scale=scale,
shear=shear,
center=center,
)
return BoundingBox.new_like(self, output, dtype=output.dtype)
def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
output = _F.perspective_bounding_box(self, self.format, perspective_coeffs)
return BoundingBox.new_like(self, output, dtype=output.dtype)
def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> BoundingBox:
raise TypeError("Erase transformation does not support bounding boxes")
def mixup(self, lam: float) -> BoundingBox:
raise TypeError("Mixup transformation does not support bounding boxes")
def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> BoundingBox:
raise TypeError("Cutmix transformation does not support bounding boxes")
from typing import Any, cast, TypeVar, Union, Optional, Type, Callable, List, Tuple, Sequence, Mapping
import torch
from torch._C import _TensorBase, DisableTorchFunction
from torchvision.transforms import InterpolationMode
F = TypeVar("F", bound="_Feature")
...@@ -83,3 +83,115 @@ class _Feature(torch.Tensor):
return cls.new_like(args[0], output, dtype=output.dtype, device=output.device)
else:
return output
def horizontal_flip(self) -> Any:
return self
def vertical_flip(self) -> Any:
return self
# TODO: We have to ignore the mypy override error because torch.Tensor has a built-in deprecated op: Tensor.resize
# https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: bool = False,
) -> Any:
return self
def crop(self, top: int, left: int, height: int, width: int) -> Any:
return self
def center_crop(self, output_size: List[int]) -> Any:
return self
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> Any:
return self
def pad(
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
) -> Any:
return self
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> Any:
return self
def affine(
self,
angle: float,
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> Any:
return self
def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
) -> Any:
return self
def adjust_brightness(self, brightness_factor: float) -> Any:
return self
def adjust_saturation(self, saturation_factor: float) -> Any:
return self
def adjust_contrast(self, contrast_factor: float) -> Any:
return self
def adjust_sharpness(self, sharpness_factor: float) -> Any:
return self
def adjust_hue(self, hue_factor: float) -> Any:
return self
def adjust_gamma(self, gamma: float, gain: float = 1) -> Any:
return self
def posterize(self, bits: int) -> Any:
return self
def solarize(self, threshold: float) -> Any:
return self
def autocontrast(self) -> Any:
return self
def equalize(self) -> Any:
return self
def invert(self) -> Any:
return self
def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Any:
return self
def mixup(self, lam: float) -> Any:
return self
def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Any:
return self
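These defaults mean a transform can be applied blindly to a heterogeneous sample: features for which an op is meaningless simply pass through unchanged. A tiny sketch, assuming this commit's prototype Label feature:

import torch
from torchvision.prototype import features

label = features.Label(torch.tensor(3))
assert label.horizontal_flip() is label  # no-op: flipping a label is meaningless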
from __future__ import annotations
import warnings
from typing import Any, List, Optional, Union, Sequence, Tuple, cast
import torch
from torchvision._utils import StrEnum
from torchvision.transforms.functional import to_pil_image, InterpolationMode
from torchvision.utils import draw_bounding_boxes
from torchvision.utils import make_grid
...@@ -109,3 +109,209 @@ class Image(_Feature):
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
def horizontal_flip(self) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.horizontal_flip_image_tensor(self)
return Image.new_like(self, output)
def vertical_flip(self) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.vertical_flip_image_tensor(self)
return Image.new_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: bool = False,
) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.resize_image_tensor(self, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
return Image.new_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.crop_image_tensor(self, top, left, height, width)
return Image.new_like(self, output)
def center_crop(self, output_size: List[int]) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.center_crop_image_tensor(self, output_size=output_size)
return Image.new_like(self, output)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.resized_crop_image_tensor(
self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias
)
return Image.new_like(self, output)
def pad(
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
) -> Image:
from torchvision.prototype.transforms import functional as _F
# PyTorch's pad supports only scalar fill values, so we need to overwrite the padded area with the colour ourselves
if isinstance(fill, (int, float)):
output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
else:
from torchvision.prototype.transforms.functional._geometry import _pad_with_vector_fill
output = _pad_with_vector_fill(self, padding, fill=fill, padding_mode=padding_mode)
return Image.new_like(self, output)
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.rotate_image_tensor(
self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
)
return Image.new_like(self, output)
def affine(
self,
angle: float,
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.affine_image_tensor(
self,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
return Image.new_like(self, output)
def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill)
return Image.new_like(self, output)
def adjust_brightness(self, brightness_factor: float) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor)
return Image.new_like(self, output)
def adjust_saturation(self, saturation_factor: float) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor)
return Image.new_like(self, output)
def adjust_contrast(self, contrast_factor: float) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor)
return Image.new_like(self, output)
def adjust_sharpness(self, sharpness_factor: float) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.adjust_sharpness_image_tensor(self, sharpness_factor=sharpness_factor)
return Image.new_like(self, output)
def adjust_hue(self, hue_factor: float) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.adjust_hue_image_tensor(self, hue_factor=hue_factor)
return Image.new_like(self, output)
def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.adjust_gamma_image_tensor(self, gamma=gamma, gain=gain)
return Image.new_like(self, output)
def posterize(self, bits: int) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.posterize_image_tensor(self, bits=bits)
return Image.new_like(self, output)
def solarize(self, threshold: float) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.solarize_image_tensor(self, threshold=threshold)
return Image.new_like(self, output)
def autocontrast(self) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.autocontrast_image_tensor(self)
return Image.new_like(self, output)
def equalize(self) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.equalize_image_tensor(self)
return Image.new_like(self, output)
def invert(self) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.invert_image_tensor(self)
return Image.new_like(self, output)
def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.erase_image_tensor(self, i, j, h, w, v)
return Image.new_like(self, output)
def mixup(self, lam: float) -> Image:
if self.ndim < 4:
raise ValueError("Need a batch of images")
output = self.clone()
output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
return Image.new_like(self, output)
def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Image:
if self.ndim < 4:
raise ValueError("Need a batch of images")
x1, y1, x2, y2 = box
image_rolled = self.roll(1, -4)
output = self.clone()
output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
return Image.new_like(self, output)
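The mixup formula above blends each image with the next one in the batch (roll along the batch dimension) by the factor lam; a plain-tensor sketch of the same math:

import torch

lam = 0.3
batch = torch.rand(4, 3, 8, 8)                       # NCHW batch of images
mixed = batch.roll(1, 0) * (1 - lam) + batch * lam   # same blend as Image.mixup
assert mixed.shape == batch.shape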
from __future__ import annotations
from typing import Any, Optional, Sequence, cast, Union, Tuple
import torch
from torchvision.prototype.utils._internal import apply_recursively
...@@ -77,3 +77,14 @@ class OneHotLabel(_Feature):
return super().new_like(
other, data, categories=categories if categories is not None else other.categories, **kwargs
)
def mixup(self, lam: float) -> OneHotLabel:
if self.ndim < 2:
raise ValueError("Need a batch of one hot labels")
output = self.clone()
output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
return OneHotLabel.new_like(self, output)
def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> OneHotLabel:
box # unused
return self.mixup(lam_adjusted)
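For one-hot labels, cutmix reduces to the same blend as mixup, just with the area-adjusted lambda; a plain-tensor sketch:

import torch

lam_adjusted = 0.75
one_hot = torch.eye(4)                                                   # batch of 4 one-hot labels
blended = one_hot.roll(1, 0) * (1 - lam_adjusted) + one_hot * lam_adjusted
assert torch.allclose(blended.sum(dim=1), torch.ones(4))                 # rows still sum to 1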
from __future__ import annotations
from typing import Tuple, List, Optional, Union, Sequence
import torch
from torchvision.transforms import InterpolationMode
from ._feature import _Feature
class SegmentationMask(_Feature):
def horizontal_flip(self) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
output = _F.horizontal_flip_segmentation_mask(self)
return SegmentationMask.new_like(self, output)
def vertical_flip(self) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
output = _F.vertical_flip_segmentation_mask(self)
return SegmentationMask.new_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
max_size: Optional[int] = None,
antialias: bool = False,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
output = _F.resize_segmentation_mask(self, size, max_size=max_size)
return SegmentationMask.new_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
output = _F.crop_segmentation_mask(self, top, left, height, width)
return SegmentationMask.new_like(self, output)
def center_crop(self, output_size: List[int]) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
output = _F.center_crop_segmentation_mask(self, output_size=output_size)
return SegmentationMask.new_like(self, output)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
antialias: bool = False,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
output = _F.resized_crop_segmentation_mask(self, top, left, height, width, size=size)
return SegmentationMask.new_like(self, output)
def pad(
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
output = _F.pad_segmentation_mask(self, padding, padding_mode=padding_mode)
return SegmentationMask.new_like(self, output)
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
output = _F.rotate_segmentation_mask(self, angle, expand=expand, center=center)
return SegmentationMask.new_like(self, output)
def affine(
self,
angle: float,
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
output = _F.affine_segmentation_mask(
self,
angle,
translate=translate,
scale=scale,
shear=shear,
center=center,
)
return SegmentationMask.new_like(self, output)
def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
output = _F.perspective_segmentation_mask(self, perspective_coeffs)
return SegmentationMask.new_like(self, output)
def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> SegmentationMask:
raise TypeError("Erase transformation does not support segmentation masks")
def mixup(self, lam: float) -> SegmentationMask:
raise TypeError("Mixup transformation does not support segmentation masks")
def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> SegmentationMask:
raise TypeError("Cutmix transformation does not support segmentation masks")
...@@ -3,12 +3,13 @@ import numbers
import warnings
from typing import Any, Dict, Tuple
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from ._transform import _RandomApplyTransform
from ._utils import query_image, get_image_dimensions, has_all
class RandomErasing(_RandomApplyTransform):
...@@ -51,7 +52,7 @@ class RandomErasing(_RandomApplyTransform):
if value is not None and not (len(value) in (1, img_c)):
raise ValueError(
f"If value is a sequence, it should have either a single value or {img_c} (number of input channels)"
)
area = img_h * img_w
...@@ -82,59 +83,45 @@ class RandomErasing(_RandomApplyTransform):
else:
i, j, h, w, v = 0, 0, img_h, img_w, image
return dict(i=i, j=j, h=h, w=w, v=v)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features._Feature):
return inpt.erase(**params)
elif isinstance(inpt, PIL.Image.Image):
# TODO: We should implement a fallback to tensor, like gaussian_blur etc
raise RuntimeError("Not implemented")
elif isinstance(inpt, torch.Tensor):
return F.erase_image_tensor(inpt, **params)
else:
return inpt
class _BaseMixupCutmix(Transform):
def __init__(self, *, alpha: float) -> None:
super().__init__()
self.alpha = alpha
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
def forward(self, *inpts: Any) -> Any:
sample = inpts if len(inpts) > 1 else inpts[0]
if not has_all(sample, features.Image, features.OneHotLabel):
raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
return super().forward(sample)
class RandomMixup(_BaseMixupCutmix):
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(())))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features._Feature):
return inpt.mixup(**params)
else:
return inpt
class RandomCutmix(_BaseMixupCutmix):
def _get_params(self, sample: Any) -> Dict[str, Any]:
lam = float(self._dist.sample(()))
...@@ -158,20 +145,8 @@ class RandomCutmix(Transform):
return dict(box=box, lam_adjusted=lam_adjusted)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features._Feature):
return inpt.cutmix(**params)
else:
return inpt
...@@ -7,72 +7,89 @@ from ._meta import (
from ._augment import (
erase_image_tensor,
)
from ._color import (
adjust_brightness,
adjust_brightness_image_tensor,
adjust_brightness_image_pil,
adjust_contrast,
adjust_contrast_image_tensor,
adjust_contrast_image_pil,
adjust_saturation,
adjust_saturation_image_tensor,
adjust_saturation_image_pil,
adjust_sharpness,
adjust_sharpness_image_tensor,
adjust_sharpness_image_pil,
adjust_hue,
adjust_hue_image_tensor,
adjust_hue_image_pil,
adjust_gamma,
adjust_gamma_image_tensor,
adjust_gamma_image_pil,
posterize,
posterize_image_tensor,
posterize_image_pil,
solarize,
solarize_image_tensor,
solarize_image_pil,
autocontrast,
autocontrast_image_tensor,
autocontrast_image_pil,
equalize,
equalize_image_tensor,
equalize_image_pil,
invert,
invert_image_tensor,
invert_image_pil,
)
from ._geometry import (
horizontal_flip,
horizontal_flip_bounding_box,
horizontal_flip_image_tensor,
horizontal_flip_image_pil,
horizontal_flip_segmentation_mask,
resize,
resize_bounding_box,
resize_image_tensor,
resize_image_pil,
resize_segmentation_mask,
center_crop,
center_crop_bounding_box,
center_crop_segmentation_mask,
center_crop_image_tensor,
center_crop_image_pil,
resized_crop,
resized_crop_bounding_box,
resized_crop_image_tensor,
resized_crop_image_pil,
resized_crop_segmentation_mask,
affine,
affine_bounding_box,
affine_image_tensor,
affine_image_pil,
affine_segmentation_mask,
rotate,
rotate_bounding_box,
rotate_image_tensor,
rotate_image_pil,
rotate_segmentation_mask,
pad,
pad_bounding_box,
pad_image_tensor,
pad_image_pil,
pad_segmentation_mask,
crop,
crop_bounding_box,
crop_image_tensor,
crop_image_pil,
crop_segmentation_mask,
perspective,
perspective_bounding_box,
perspective_image_tensor,
perspective_image_pil,
perspective_segmentation_mask,
vertical_flip,
vertical_flip_image_tensor,
vertical_flip_image_pil,
vertical_flip_bounding_box,
...
from torchvision.transforms import functional_tensor as _FT
erase_image_tensor = _FT.erase
# TODO: Don't forget to clean up from the primitive kernels those that shouldn't be kernels,
# like the mixup and cutmix stuff.
# This function is copy-pasted to Image and OneHotLabel and may be refactored:
# def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor:
#     input = input.clone()
#     return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam))
from typing import Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
# shortcut type
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
adjust_brightness_image_tensor = _FT.adjust_brightness
adjust_brightness_image_pil = _FP.adjust_brightness
def adjust_brightness(inpt: DType, brightness_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_brightness(brightness_factor=brightness_factor)
if isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
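Since the dispatcher is a thin wrapper, calling it on a plain tensor is equivalent to calling the low-level kernel directly; a quick sketch, assuming this commit's prototype API:

import torch
from torchvision.prototype.transforms import functional as F

t = torch.rand(3, 16, 16)
assert torch.equal(
    F.adjust_brightness(t, brightness_factor=1.5),
    F.adjust_brightness_image_tensor(t, brightness_factor=1.5),
)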
adjust_saturation_image_tensor = _FT.adjust_saturation
adjust_saturation_image_pil = _FP.adjust_saturation
def adjust_saturation(inpt: DType, saturation_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_saturation(saturation_factor=saturation_factor)
if isinstance(inpt, PIL.Image.Image):
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
adjust_contrast_image_tensor = _FT.adjust_contrast
adjust_contrast_image_pil = _FP.adjust_contrast
def adjust_contrast(inpt: DType, contrast_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_contrast(contrast_factor=contrast_factor)
if isinstance(inpt, PIL.Image.Image):
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
adjust_sharpness_image_tensor = _FT.adjust_sharpness
adjust_sharpness_image_pil = _FP.adjust_sharpness
def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
if isinstance(inpt, PIL.Image.Image):
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
adjust_hue_image_tensor = _FT.adjust_hue
adjust_hue_image_pil = _FP.adjust_hue
def adjust_hue(inpt: DType, hue_factor: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_hue(hue_factor=hue_factor)
if isinstance(inpt, PIL.Image.Image):
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
adjust_gamma_image_tensor = _FT.adjust_gamma
adjust_gamma_image_pil = _FP.adjust_gamma
def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType:
if isinstance(inpt, features._Feature):
return inpt.adjust_gamma(gamma=gamma, gain=gain)
if isinstance(inpt, PIL.Image.Image):
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
posterize_image_tensor = _FT.posterize
posterize_image_pil = _FP.posterize
def posterize(inpt: DType, bits: int) -> DType:
if isinstance(inpt, features._Feature):
return inpt.posterize(bits=bits)
if isinstance(inpt, PIL.Image.Image):
return posterize_image_pil(inpt, bits=bits)
return posterize_image_tensor(inpt, bits=bits)
solarize_image_tensor = _FT.solarize
solarize_image_pil = _FP.solarize
def solarize(inpt: DType, threshold: float) -> DType:
if isinstance(inpt, features._Feature):
return inpt.solarize(threshold=threshold)
if isinstance(inpt, PIL.Image.Image):
return solarize_image_pil(inpt, threshold=threshold)
return solarize_image_tensor(inpt, threshold=threshold)
autocontrast_image_tensor = _FT.autocontrast
autocontrast_image_pil = _FP.autocontrast
def autocontrast(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.autocontrast()
if isinstance(inpt, PIL.Image.Image):
return autocontrast_image_pil(inpt)
return autocontrast_image_tensor(inpt)
equalize_image_tensor = _FT.equalize
equalize_image_pil = _FP.equalize
def equalize(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.equalize()
if isinstance(inpt, PIL.Image.Image):
return equalize_image_pil(inpt)
return equalize_image_tensor(inpt)
invert_image_tensor = _FT.invert
invert_image_pil = _FP.invert
def invert(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.invert()
if isinstance(inpt, PIL.Image.Image):
return invert_image_pil(inpt)
return invert_image_tensor(inpt)
...@@ -16,6 +16,10 @@ from torchvision.transforms.functional import (
from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil
# shortcut type
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
horizontal_flip_image_tensor = _FT.hflip
horizontal_flip_image_pil = _FP.hflip
...@@ -40,12 +44,52 @@ def horizontal_flip_bounding_box(
).view(shape)
def horizontal_flip(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.horizontal_flip()
if isinstance(inpt, PIL.Image.Image):
return horizontal_flip_image_pil(inpt)
return horizontal_flip_image_tensor(inpt)
vertical_flip_image_tensor = _FT.vflip
vertical_flip_image_pil = _FP.vflip
def vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
return vertical_flip_image_tensor(segmentation_mask)
def vertical_flip_bounding_box(
bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
shape = bounding_box.shape
bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]]
return convert_bounding_box_format(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(shape)
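The y-coordinate update above mirrors the box around the image height and swaps y1/y2 so the box stays well-formed; a plain-tensor example for a single XYXY box:

import torch

image_height = 100
box = torch.tensor([[10.0, 20.0, 30.0, 60.0]])        # x1, y1, x2, y2
flipped = box.clone()
flipped[:, [1, 3]] = image_height - box[:, [3, 1]]
assert torch.equal(flipped, torch.tensor([[10.0, 40.0, 30.0, 80.0]]))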
def vertical_flip(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
return inpt.vertical_flip()
if isinstance(inpt, PIL.Image.Image):
return vertical_flip_image_pil(inpt)
return vertical_flip_image_tensor(inpt)
def resize_image_tensor(
image: torch.Tensor,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: bool = False,
) -> torch.Tensor:
num_channels, old_height, old_width = get_dimensions_image_tensor(image)
new_height, new_width = _compute_output_size((old_height, old_width), size=size, max_size=max_size)
...@@ -87,28 +131,23 @@ def resize_bounding_box(
return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)
def resize(
inpt: DType,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> DType:
if isinstance(inpt, features._Feature):
antialias = False if antialias is None else antialias
return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
if isinstance(inpt, PIL.Image.Image):
if antialias is not None and not antialias:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
antialias = False if antialias is None else antialias
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
def _affine_parse_args(
...@@ -323,6 +362,43 @@ def affine_segmentation_mask(
)
def affine(
inpt: DType,
angle: float,
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> DType:
if isinstance(inpt, features._Feature):
return inpt.affine(
angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
)
if isinstance(inpt, PIL.Image.Image):
return affine_image_pil(
inpt,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
return affine_image_tensor(
inpt,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
def rotate_image_tensor(
img: torch.Tensor,
angle: float,
...@@ -402,10 +478,63 @@ def rotate_segmentation_mask(
)
def rotate(
inpt: DType,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> DType:
if isinstance(inpt, features._Feature):
return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
if isinstance(inpt, PIL.Image.Image):
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
pad_image_pil = _FP.pad
def pad_image_tensor(
img: torch.Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant"
) -> torch.Tensor:
num_masks, height, width = img.shape[-3:]
extra_dims = img.shape[:-3]
padded_image = _FT.pad(
img=img.view(-1, num_masks, height, width), padding=padding, fill=fill, padding_mode=padding_mode
)
new_height, new_width = padded_image.shape[-2:]
return padded_image.view(extra_dims + (num_masks, new_height, new_width))
# TODO: This should be removed once pytorch pad supports non-scalar padding values
def _pad_with_vector_fill(
img: torch.Tensor,
padding: List[int],
fill: Sequence[float] = [0.0],
padding_mode: str = "constant",
) -> torch.Tensor:
if padding_mode != "constant":
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
output = pad_image_tensor(img, padding, fill=0, padding_mode="constant")
left, top, right, bottom = padding
fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1)
if top > 0:
output[..., :top, :] = fill
if left > 0:
output[..., :, :left] = fill
if bottom > 0:
output[..., -bottom:, :] = fill
if right > 0:
output[..., :, -right:] = fill
return output
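The workaround above amounts to padding with zeros and then overwriting the padded border with the per-channel fill colour; a self-contained plain-tensor sketch of the same idea (not the torchvision implementation itself):

import torch
import torch.nn.functional as nnF

img = torch.rand(3, 4, 4)
left, top, right, bottom = 1, 1, 2, 2
fill = torch.tensor([1.0, 0.5, 0.0]).view(-1, 1, 1)   # one value per channel

out = nnF.pad(img, (left, right, top, bottom))        # zero padding (width pads first, then height)
out[..., :top, :] = fill                              # overwrite the padded border with the colour
out[..., -bottom:, :] = fill
out[..., :, :left] = fill
out[..., :, -right:] = fill
assert out.shape == (3, 4 + top + bottom, 4 + left + right)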
def pad_segmentation_mask(
segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
) -> torch.Tensor:
...@@ -436,6 +565,21 @@ def pad_bounding_box(
return bounding_box
def pad(
inpt: DType, padding: List[int], fill: Union[int, float, Sequence[float]] = 0.0, padding_mode: str = "constant"
) -> DType:
if isinstance(inpt, features._Feature):
return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
if isinstance(inpt, PIL.Image.Image):
return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
# TODO: PyTorch's pad supports only scalar fill values, so we need to overwrite the padded area with the colour ourselves
if isinstance(fill, (int, float)):
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode)
crop_image_tensor = _FT.crop
crop_image_pil = _FP.crop
...@@ -463,6 +607,14 @@ def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int,
return crop_image_tensor(img, top, left, height, width)
def crop(inpt: DType, top: int, left: int, height: int, width: int) -> DType:
if isinstance(inpt, features._Feature):
return inpt.crop(top, left, height, width)
if isinstance(inpt, PIL.Image.Image):
return crop_image_pil(inpt, top, left, height, width)
return crop_image_tensor(inpt, top, left, height, width)
def perspective_image_tensor(
img: torch.Tensor,
perspective_coeffs: List[float],
...@@ -474,7 +626,7 @@ def perspective_image_tensor(
def perspective_image_pil(
img: PIL.Image.Image,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BICUBIC,
fill: Optional[List[float]] = None,
) -> PIL.Image.Image:
...@@ -570,13 +722,25 @@ def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[fl
return perspective_image_tensor(img, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST)
def perspective(
inpt: DType,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
) -> DType:
if isinstance(inpt, features._Feature):
return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill)
if isinstance(inpt, PIL.Image.Image):
return perspective_image_pil(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
if isinstance(output_size, numbers.Number):
return [int(output_size), int(output_size)]
if isinstance(output_size, (tuple, list)) and len(output_size) == 1:
return [output_size[0], output_size[0]]
return list(output_size)
def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
...@@ -643,6 +807,14 @@ def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size:
return center_crop_image_tensor(img=segmentation_mask, output_size=output_size)
def center_crop(inpt: DType, output_size: List[int]) -> DType:
if isinstance(inpt, features._Feature):
return inpt.center_crop(output_size)
if isinstance(inpt, PIL.Image.Image):
return center_crop_image_pil(inpt, output_size)
return center_crop_image_tensor(inpt, output_size)
def resized_crop_image_tensor(
img: torch.Tensor,
top: int,
...@@ -651,9 +823,10 @@ def resized_crop_image_tensor(
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> torch.Tensor:
img = crop_image_tensor(img, top, left, height, width)
return resize_image_tensor(img, size, interpolation=interpolation, antialias=antialias)
def resized_crop_image_pil(
...@@ -694,6 +867,27 @@ def resized_crop_segmentation_mask(
return resize_segmentation_mask(mask, size)
def resized_crop(
inpt: DType,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> DType:
if isinstance(inpt, features._Feature):
antialias = False if antialias is None else antialias
return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
if isinstance(inpt, PIL.Image.Image):
return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
antialias = False if antialias is None else antialias
return resized_crop_image_tensor(
inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
)
def _parse_five_crop_size(size: List[int]) -> List[int]:
if isinstance(size, numbers.Number):
size = [int(size), int(size)]
...
...@@ -317,9 +317,9 @@ def rotate(
@torch.jit.unused
def perspective(
img: Image.Image,
perspective_coeffs: List[float],
interpolation: int = _pil_constants.BICUBIC,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = None,
) -> Image.Image:
if not _is_pil_image(img):
...