Unverified commit 54dd0a59, authored by Vasilis Vryniotis, committed by GitHub

Prototype Transform cleanups and bugfixing (#6512)



* Clean up bboxes.
* Correct a bug in RandAugment.
* Fix a bug in erase().
* Move `_parse_pad_padding()` to functional geometry.
* Fix a bug in the deprecated Grayscale transforms.
* Port old bugfixes from the main branch to `_apply_image_transform`.
* Drop mandatory keyword arguments to maintain backwards compatibility (BC).
* Add the `antialias` option where possible.
* Specify types in `_transform()` where possible.
* Add a TODO.
* Fix tests.
* Add `padding_mode` in `pad_bbox`.

Co-authored-by: vfdev <vfdev.5@gmail.com>
parent 26099237
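
The user-visible thread running through these commits is the new optional `antialias` flag on the resizing code paths, forwarded verbatim to `F.resize`. A minimal usage sketch against the prototype API at this commit (all argument values other than `antialias` are illustrative):

    import torch
    from torchvision.prototype import transforms

    # antialias=None keeps the historical tensor behavior (no antialiasing);
    # antialias=True opts in for closer parity with PIL resizing.
    scale_jitter = transforms.ScaleJitter(target_size=(512, 512), antialias=True)
    shortest = transforms.RandomShortestSize(min_size=[480, 512], max_size=1024, antialias=True)

    img = torch.rand(3, 480, 640)
    out = shortest(scale_jitter(img))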
@@ -1272,8 +1272,11 @@ class TestScaleJitter:
     def test__transform(self, mocker):
         interpolation_sentinel = mocker.MagicMock()
+        antialias_sentinel = mocker.MagicMock()

-        transform = transforms.ScaleJitter(target_size=(16, 12), interpolation=interpolation_sentinel)
+        transform = transforms.ScaleJitter(
+            target_size=(16, 12), interpolation=interpolation_sentinel, antialias=antialias_sentinel
+        )
         transform._transformed_types = (mocker.MagicMock,)

         size_sentinel = mocker.MagicMock()
@@ -1286,7 +1289,9 @@ class TestScaleJitter:
         mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
         transform(inpt_sentinel)

-        mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
+        mock.assert_called_once_with(
+            inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
+        )


 class TestRandomShortestSize:
@@ -1316,8 +1321,11 @@ class TestRandomShortestSize:
     def test__transform(self, mocker):
         interpolation_sentinel = mocker.MagicMock()
+        antialias_sentinel = mocker.MagicMock()

-        transform = transforms.RandomShortestSize(min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel)
+        transform = transforms.RandomShortestSize(
+            min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel, antialias=antialias_sentinel
+        )
         transform._transformed_types = (mocker.MagicMock,)

         size_sentinel = mocker.MagicMock()
@@ -1331,7 +1339,9 @@ class TestRandomShortestSize:
         mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
         transform(inpt_sentinel)

-        mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
+        mock.assert_called_once_with(
+            inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
+        )


 class TestSimpleCopyPaste:
@@ -1404,6 +1414,9 @@ class TestSimpleCopyPaste:
         masks[0, 3:9, 2:8] = 1
         masks[1, 20:30, 20:30] = 1
         labels = torch.tensor([1, 2])
+        blending = True
+        resize_interpolation = InterpolationMode.BILINEAR
+        antialias = None
         if label_type == features.OneHotLabel:
             labels = torch.nn.functional.one_hot(labels, num_classes=5)
         target = {
@@ -1431,7 +1444,9 @@ class TestSimpleCopyPaste:
         transform = transforms.SimpleCopyPaste()
         random_selection = torch.tensor([0, 1])

-        output_image, output_target = transform._copy_paste(image, target, paste_image, paste_target, random_selection)
+        output_image, output_target = transform._copy_paste(
+            image, target, paste_image, paste_target, random_selection, blending, resize_interpolation, antialias
+        )

         assert output_image.unique().tolist() == [2, 10]
         assert output_target["boxes"].shape == (4, 4)
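
These tests follow a sentinel-forwarding pattern: opaque mock objects are passed in as parameters, and the patched kernel is asserted to receive exactly those objects, which proves the transform forwards `antialias` without inspecting or mutating it. A self-contained sketch of the same pattern using only the standard library (the `apply_resize` wrapper is hypothetical, not torchvision code):

    from unittest import mock

    def apply_resize(resize_fn, inpt, size, interpolation, antialias):
        # Forward everything verbatim to the kernel, as the transforms do.
        return resize_fn(inpt, size=size, interpolation=interpolation, antialias=antialias)

    resize_mock = mock.MagicMock()
    inpt, size, interp, aa = mock.sentinel.inpt, mock.sentinel.size, mock.sentinel.interp, mock.sentinel.aa
    apply_resize(resize_mock, inpt, size, interp, aa)
    # Passes only if every sentinel reached the kernel unchanged.
    resize_mock.assert_called_once_with(inpt, size=size, interpolation=interp, antialias=aa)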
......
 import math
 import numbers
 import warnings
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 import PIL.Image
 import torch
@@ -15,6 +15,8 @@ from ._utils import has_any, query_chw


 class RandomErasing(_RandomApplyTransform):
+    _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image)
+
     def __init__(
         self,
         p: float = 0.5,
@@ -86,7 +88,9 @@ class RandomErasing(_RandomApplyTransform):
         return dict(i=i, j=j, h=h, w=w, v=v)

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def _transform(
+        self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any]
+    ) -> Union[torch.Tensor, features.Image, PIL.Image.Image]:
         if params["v"] is not None:
             inpt = F.erase(inpt, **params)
@@ -94,7 +98,7 @@


 class _BaseMixupCutmix(_RandomApplyTransform):
-    def __init__(self, *, alpha: float, p: float = 0.5) -> None:
+    def __init__(self, alpha: float, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.alpha = alpha
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
@@ -188,10 +192,12 @@ class SimpleCopyPaste(_RandomApplyTransform):
         p: float = 0.5,
         blending: bool = True,
         resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR,
+        antialias: Optional[bool] = None,
     ) -> None:
         super().__init__(p=p)
         self.resize_interpolation = resize_interpolation
         self.blending = blending
+        self.antialias = antialias

     def _copy_paste(
         self,
@@ -200,8 +206,9 @@ class SimpleCopyPaste(_RandomApplyTransform):
         paste_image: Any,
         paste_target: Dict[str, Any],
         random_selection: torch.Tensor,
-        blending: bool = True,
-        resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
+        blending: bool,
+        resize_interpolation: F.InterpolationMode,
+        antialias: Optional[bool],
     ) -> Tuple[Any, Dict[str, Any]]:

         paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
@@ -217,7 +224,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
         size1 = image.shape[-2:]
         size2 = paste_image.shape[-2:]
         if size1 != size2:
-            paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation)
+            paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias)
             paste_masks = F.resize(paste_masks, size=size1)
             paste_boxes = F.resize(paste_boxes, size=size1)
@@ -356,6 +363,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
                 random_selection=random_selection,
                 blending=self.blending,
                 resize_interpolation=self.resize_interpolation,
+                antialias=self.antialias,
             )
             output_images.append(output_image)
             output_targets.append(output_target)
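
Dropping the default values on `_copy_paste` is deliberate: the user-facing defaults now live only on the constructor, and `forward` must thread every option through explicitly, so the private helper can no longer silently drift out of sync with the transform's attributes. A toy sketch of the pattern (hypothetical `Paster` class, not torchvision code):

    from typing import Optional

    class Paster:
        def __init__(self, blending: bool = True, antialias: Optional[bool] = None) -> None:
            self.blending = blending      # defaults are declared exactly once, here
            self.antialias = antialias

        def _paste(self, a: float, b: float, blending: bool, antialias: Optional[bool]) -> float:
            # No defaults here: every caller must state what it wants explicitly.
            return 0.5 * (a + b) if blending else b

        def forward(self, a: float, b: float) -> float:
            return self._paste(a, b, self.blending, self.antialias)

    assert Paster(blending=False).forward(1.0, 3.0) == 3.0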
......
@@ -116,8 +116,8 @@ class _AutoAugmentBase(Transform):
                 angle=0.0,
                 translate=[int(magnitude), 0],
                 scale=1.0,
-                shear=[0.0, 0.0],
                 interpolation=interpolation,
+                shear=[0.0, 0.0],
                 fill=fill_,
             )
         elif transform_id == "TranslateY":
@@ -126,8 +126,8 @@ class _AutoAugmentBase(Transform):
                 angle=0.0,
                 translate=[0, int(magnitude)],
                 scale=1.0,
-                shear=[0.0, 0.0],
                 interpolation=interpolation,
+                shear=[0.0, 0.0],
                 fill=fill_,
             )
         elif transform_id == "Rotate":
......
@@ -112,7 +112,7 @@ class RandomPhotometricDistort(Transform):
             channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
         )

-    def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any:
+    def _permute_channels(self, inpt: Any, permutation: torch.Tensor) -> Any:
         if isinstance(inpt, PIL.Image.Image):
             inpt = F.to_image_tensor(inpt)
@@ -125,7 +125,9 @@ class RandomPhotometricDistort(Transform):
         return output

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def _transform(
+        self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any]
+    ) -> Union[torch.Tensor, features.Image, PIL.Image.Image]:
         if params["brightness"]:
             inpt = F.adjust_brightness(
                 inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
......
@@ -22,7 +22,7 @@ class Compose(Transform):


 class RandomApply(_RandomApplyTransform):
-    def __init__(self, transform: Transform, *, p: float = 0.5) -> None:
+    def __init__(self, transform: Transform, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.transform = transform
......
 import warnings
-from typing import Any, Dict
+from typing import Any, Dict, Union

 import numpy as np
 import PIL.Image
@@ -14,6 +14,9 @@ from ._transform import _RandomApplyTransform
 from ._utils import query_chw

+DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
+

 class ToTensor(Transform):
     _transformed_types = (PIL.Image.Image, np.ndarray)
@@ -24,7 +27,7 @@ class ToTensor(Transform):
         )
         super().__init__()

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
+    def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
         return _F.to_tensor(inpt)
@@ -52,8 +55,11 @@ class Grayscale(Transform):
         super().__init__()
         self.num_output_channels = num_output_channels

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
+    def _transform(self, inpt: DType, params: Dict[str, Any]) -> DType:
+        output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
+        if isinstance(inpt, features.Image):
+            output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY)
+        return output


 class RandomGrayscale(_RandomApplyTransform):
@@ -78,5 +84,8 @@ class RandomGrayscale(_RandomApplyTransform):
         num_input_channels, _, _ = query_chw(sample)
         return dict(num_input_channels=num_input_channels)

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
+    def _transform(self, inpt: DType, params: Dict[str, Any]) -> DType:
+        output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
+        if isinstance(inpt, features.Image):
+            output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY)
+        return output
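
The Grayscale fix addresses stale metadata: the kernel returns an unwrapped result, so for `features.Image` inputs the transform must re-wrap the output and overwrite its `color_space` tag, otherwise a one-channel image would still claim to be RGB. A toy sketch of the re-wrap pattern (hypothetical `TaggedTensor`, not the real `features.Image` API):

    from typing import Optional

    import torch

    class TaggedTensor:
        """Toy stand-in for features.Image: pixel data plus color-space metadata."""

        def __init__(self, data: torch.Tensor, color_space: str) -> None:
            self.data = data
            self.color_space = color_space

        @classmethod
        def new_like(cls, other: "TaggedTensor", data: torch.Tensor, color_space: Optional[str] = None) -> "TaggedTensor":
            # Copy metadata from `other` unless the caller overrides it.
            return cls(data, color_space or other.color_space)

    def to_gray(inpt: TaggedTensor) -> TaggedTensor:
        output = inpt.data.float().mean(dim=0, keepdim=True)  # the "kernel": plain tensor out
        # Without this re-wrap, the result would keep the stale "RGB" tag.
        return TaggedTensor.new_like(inpt, output, color_space="GRAY")

    img = TaggedTensor(torch.rand(3, 4, 4), color_space="RGB")
    assert to_gray(img).color_space == "GRAY"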
@@ -15,6 +15,9 @@ from ._transform import _RandomApplyTransform
 from ._utils import _check_sequence_input, _setup_angle, _setup_size, has_all, has_any, query_bounding_box, query_chw

+DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
+

 class RandomHorizontalFlip(_RandomApplyTransform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.horizontal_flip(inpt)
@@ -163,7 +166,7 @@ class FiveCrop(Transform):
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def _transform(self, inpt: DType, params: Dict[str, Any]) -> Tuple[DType, DType, DType, DType, DType]:
         return F.five_crop(inpt, self.size)

     def forward(self, *inputs: Any) -> Any:
@@ -184,7 +187,7 @@ class TenCrop(Transform):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def _transform(self, inpt: DType, params: Dict[str, Any]) -> List[DType]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)

     def forward(self, *inputs: Any) -> Any:
@@ -713,11 +716,13 @@ class ScaleJitter(Transform):
         target_size: Tuple[int, int],
         scale_range: Tuple[float, float] = (0.1, 2.0),
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        antialias: Optional[bool] = None,
     ):
         super().__init__()
         self.target_size = target_size
         self.scale_range = scale_range
         self.interpolation = interpolation
+        self.antialias = antialias

     def _get_params(self, sample: Any) -> Dict[str, Any]:
         _, orig_height, orig_width = query_chw(sample)
@@ -729,7 +734,7 @@ class ScaleJitter(Transform):
         return dict(size=(new_height, new_width))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resize(inpt, size=params["size"], interpolation=self.interpolation)
+        return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)


 class RandomShortestSize(Transform):
@@ -738,11 +743,13 @@ class RandomShortestSize(Transform):
         min_size: Union[List[int], Tuple[int], int],
         max_size: int,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        antialias: Optional[bool] = None,
     ):
         super().__init__()
         self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
         self.max_size = max_size
         self.interpolation = interpolation
+        self.antialias = antialias

     def _get_params(self, sample: Any) -> Dict[str, Any]:
         _, orig_height, orig_width = query_chw(sample)
@@ -756,7 +763,7 @@ class RandomShortestSize(Transform):
         return dict(size=(new_height, new_width))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resize(inpt, size=params["size"], interpolation=self.interpolation)
+        return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)


 class FixedSizeCrop(Transform):
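
Storing `antialias` on `self` matters because of how these transforms dispatch: `_get_params` draws the random target size once per call, and `_transform` is then invoked with those same params for every component of the sample, keeping images, masks, and boxes geometrically consistent. A heavily simplified sketch of that contract (the real `Transform` also flattens nested samples and filters by `_transformed_types`):

    from typing import Any, Dict

    class MiniTransform:
        """Simplified contract: draw params once, apply them to each component."""

        def _get_params(self, sample: Any) -> Dict[str, Any]:
            raise NotImplementedError

        def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
            raise NotImplementedError

        def __call__(self, *inputs: Any) -> Any:
            params = self._get_params(inputs)  # one random draw per call...
            outputs = tuple(self._transform(inpt, params) for inpt in inputs)
            return outputs[0] if len(outputs) == 1 else outputs

    class MiniScaleJitter(MiniTransform):
        def _get_params(self, sample: Any) -> Dict[str, Any]:
            return {"size": (300, 400)}  # stand-in for the random size draw

        def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
            # ...so image, masks, and boxes all see the same target size.
            return (inpt, params["size"])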
......
@@ -16,7 +16,7 @@ class ConvertBoundingBoxFormat(Transform):
             format = features.BoundingBoxFormat[format]
         self.format = format

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
         output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
         return features.BoundingBox.new_like(inpt, output, format=params["format"])
@@ -28,9 +28,11 @@ class ConvertImageDtype(Transform):
         super().__init__()
         self.dtype = dtype

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def _transform(
+        self, inpt: Union[torch.Tensor, features.Image], params: Dict[str, Any]
+    ) -> Union[torch.Tensor, features.Image]:
         output = F.convert_image_dtype(inpt, dtype=self.dtype)
-        return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)
+        return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)  # type: ignore[arg-type]


 class ConvertColorSpace(Transform):
@@ -54,7 +56,9 @@ class ConvertColorSpace(Transform):
         self.copy = copy

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def _transform(
+        self, inpt: Union[torch.Tensor, PIL.Image.Image, features._Feature], params: Dict[str, Any]
+    ) -> Union[torch.Tensor, PIL.Image.Image, features._Feature]:
         return F.convert_color_space(
             inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
         )
@@ -63,6 +67,6 @@ class ConvertColorSpace(Transform):


 class ClampBoundingBoxes(Transform):
     _transformed_types = (features.BoundingBox,)

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
         output = F.clamp_bounding_box(inpt, format=inpt.format, image_size=inpt.image_size)
         return features.BoundingBox.new_like(inpt, output)
@@ -68,7 +68,7 @@ class LinearTransformation(Transform):
         return super().forward(*inputs)

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
+    def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor:
         # Image instance after linear transformation is not Image anymore due to unknown data range
         # Thus we will return Tensor for input Image
@@ -100,7 +100,7 @@ class Normalize(Transform):
         self.mean = list(mean)
         self.std = list(std)

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor:
         return F.normalize(inpt, mean=self.mean, std=self.std)

     def forward(self, *inpts: Any) -> Any:
......
@@ -56,7 +56,7 @@ class Transform(nn.Module):


 class _RandomApplyTransform(Transform):
-    def __init__(self, *, p: float = 0.5) -> None:
+    def __init__(self, p: float = 0.5) -> None:
         if not (0.0 <= p <= 1.0):
             raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
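
Removing the bare `*` restores the stable-API calling convention, where `p` may be passed positionally. A quick before/after sketch against the prototype API at this commit:

    from torchvision.prototype import transforms

    hflip = transforms.RandomHorizontalFlip(p=1.0)

    wrapped = transforms.RandomApply(hflip, 0.5)    # raised TypeError while p was keyword-only
    wrapped = transforms.RandomApply(hflip, p=0.5)  # keyword form keeps working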
......
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

 import numpy as np
 import PIL.Image
 import torch
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
@@ -11,7 +12,7 @@ from torchvision.prototype.transforms import functional as F, Transform


 class DecodeImage(Transform):
     _transformed_types = (features.EncodedImage,)

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
+    def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image:
         return F.decode_image_with_pil(inpt)
@@ -39,18 +40,22 @@ class LabelToOneHot(Transform):


 class ToImageTensor(Transform):
     _transformed_types = (features.is_simple_tensor, PIL.Image.Image, np.ndarray)

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
+    def _transform(
+        self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
+    ) -> features.Image:
         return F.to_image_tensor(inpt)


 class ToImagePIL(Transform):
     _transformed_types = (features.is_simple_tensor, features.Image, np.ndarray)

-    def __init__(self, *, mode: Optional[str] = None) -> None:
+    def __init__(self, mode: Optional[str] = None) -> None:
         super().__init__()
         self.mode = mode

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
+    def _transform(
+        self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
+    ) -> PIL.Image.Image:
         return F.to_image_pil(inpt, mode=self.mode)
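
With `mode` positional again, `ToImagePIL` matches the ergonomics of the stable API. A round-trip sketch against the prototype converters at this commit:

    import PIL.Image
    from torchvision.prototype import transforms

    to_tensor = transforms.ToImageTensor()
    to_pil = transforms.ToImagePIL("RGB")  # mode may now be passed positionally

    tensor_img = to_tensor(PIL.Image.new("RGB", (8, 8)))  # -> features.Image, shape (3, 8, 8)
    assert isinstance(to_pil(tensor_img), PIL.Image.Image)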
......
+# TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators
+
 from torchvision.transforms import InterpolationMode  # usort: skip

 from ._meta import (
     clamp_bounding_box,
......
-from typing import Any
+from typing import Union

 import PIL.Image
@@ -19,7 +19,15 @@ def erase_image_pil(
     return to_pil_image(output, mode=img.mode)

-def erase(inpt: Any, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False) -> Any:
+def erase(
+    inpt: Union[torch.Tensor, PIL.Image.Image, features.Image],
+    i: int,
+    j: int,
+    h: int,
+    w: int,
+    v: torch.Tensor,
+    inplace: bool = False,
+) -> Union[torch.Tensor, PIL.Image.Image, features.Image]:
     if isinstance(inpt, torch.Tensor):
         output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
         if isinstance(inpt, features.Image):
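
The `erase()` fix is about dispatch: the old one-line signature implied tensors only, while the kernel split above (`erase_image_tensor` / `erase_image_pil`) supports both input kinds. A small sketch exercising both paths (shapes illustrative):

    import PIL.Image
    import torch
    from torchvision.prototype.transforms import functional as F

    patch = torch.zeros(3, 4, 4)

    out_t = F.erase(torch.rand(3, 10, 10), i=2, j=3, h=4, w=4, v=patch)           # tensor path
    out_p = F.erase(PIL.Image.new("RGB", (10, 10)), i=2, j=3, h=4, w=4, v=patch)  # PIL path
    assert isinstance(out_p, PIL.Image.Image)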
......
 import warnings
-from typing import Any
+from typing import Any, Union

 import PIL.Image
 import torch
@@ -21,7 +21,9 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
     return _F.to_grayscale(inpt, num_output_channels=num_output_channels)

-def rgb_to_grayscale(inpt: Any, num_output_channels: int = 1) -> Any:
+def rgb_to_grayscale(
+    inpt: Union[PIL.Image.Image, torch.Tensor], num_output_channels: int = 1
+) -> Union[PIL.Image.Image, torch.Tensor]:
     old_color_space = features.Image.guess_color_space(inpt) if features.is_simple_tensor(inpt) else None

     call = ", num_output_channels=3" if num_output_channels == 3 else ""
......
@@ -1133,7 +1133,7 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo
     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]

-def ten_crop(inpt: DType, size: List[int], *, vertical_flip: bool = False) -> List[DType]:
+def ten_crop(inpt: DType, size: List[int], vertical_flip: bool = False) -> List[DType]:
     if isinstance(inpt, torch.Tensor):
         output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
         if isinstance(inpt, features.Image):
......
@@ -203,7 +203,11 @@ def convert_color_space_image_pil(

 def convert_color_space(
-    inpt: Any, *, color_space: ColorSpace, old_color_space: Optional[ColorSpace] = None, copy: bool = True
+    inpt: Union[PIL.Image.Image, torch.Tensor, features._Feature],
+    *,
+    color_space: ColorSpace,
+    old_color_space: Optional[ColorSpace] = None,
+    copy: bool = True,
 ) -> Any:
     if isinstance(inpt, Image):
         return inpt.to_color_space(color_space, copy=copy)
......