Unverified commit 54dd0a59, authored by Vasilis Vryniotis and committed by GitHub
Browse files

Prototype Transform cleanups and bugfixing (#6512)



* clean up bboxes

* Correcting bug on RandAugment.

* Fix bug on erase()

* Moving `_parse_pad_padding()` to functional geometry

* Fixing bug on deprecated Grayscale transforms

* Porting old bugfixes from main branch to _apply_image_transform

* Dropping mandatory keyword arguments to maintain BC.

* Adding antialias option where possible.

* Specifying types in `_transform()` where possible.

* Add todo.

* Fixing tests.

* Adding padding_mode in pad_bbox
Co-authored-by: vfdev <vfdev.5@gmail.com>
parent 26099237
...@@ -1272,8 +1272,11 @@ class TestScaleJitter: ...@@ -1272,8 +1272,11 @@ class TestScaleJitter:
def test__transform(self, mocker): def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock() interpolation_sentinel = mocker.MagicMock()
antialias_sentinel = mocker.MagicMock()
transform = transforms.ScaleJitter(target_size=(16, 12), interpolation=interpolation_sentinel) transform = transforms.ScaleJitter(
target_size=(16, 12), interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
transform._transformed_types = (mocker.MagicMock,) transform._transformed_types = (mocker.MagicMock,)
size_sentinel = mocker.MagicMock() size_sentinel = mocker.MagicMock()
...@@ -1286,7 +1289,9 @@ class TestScaleJitter: ...@@ -1286,7 +1289,9 @@ class TestScaleJitter:
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize") mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
transform(inpt_sentinel) transform(inpt_sentinel)
mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel) mock.assert_called_once_with(
inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
class TestRandomShortestSize: class TestRandomShortestSize:
...@@ -1316,8 +1321,11 @@ class TestRandomShortestSize: ...@@ -1316,8 +1321,11 @@ class TestRandomShortestSize:
def test__transform(self, mocker): def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock() interpolation_sentinel = mocker.MagicMock()
antialias_sentinel = mocker.MagicMock()
transform = transforms.RandomShortestSize(min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel) transform = transforms.RandomShortestSize(
min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
transform._transformed_types = (mocker.MagicMock,) transform._transformed_types = (mocker.MagicMock,)
size_sentinel = mocker.MagicMock() size_sentinel = mocker.MagicMock()
...@@ -1331,7 +1339,9 @@ class TestRandomShortestSize: ...@@ -1331,7 +1339,9 @@ class TestRandomShortestSize:
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize") mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
transform(inpt_sentinel) transform(inpt_sentinel)
mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel) mock.assert_called_once_with(
inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
class TestSimpleCopyPaste: class TestSimpleCopyPaste:
...@@ -1404,6 +1414,9 @@ class TestSimpleCopyPaste: ...@@ -1404,6 +1414,9 @@ class TestSimpleCopyPaste:
masks[0, 3:9, 2:8] = 1 masks[0, 3:9, 2:8] = 1
masks[1, 20:30, 20:30] = 1 masks[1, 20:30, 20:30] = 1
labels = torch.tensor([1, 2]) labels = torch.tensor([1, 2])
blending = True
resize_interpolation = InterpolationMode.BILINEAR
antialias = None
if label_type == features.OneHotLabel: if label_type == features.OneHotLabel:
labels = torch.nn.functional.one_hot(labels, num_classes=5) labels = torch.nn.functional.one_hot(labels, num_classes=5)
target = { target = {
...@@ -1431,7 +1444,9 @@ class TestSimpleCopyPaste: ...@@ -1431,7 +1444,9 @@ class TestSimpleCopyPaste:
transform = transforms.SimpleCopyPaste() transform = transforms.SimpleCopyPaste()
random_selection = torch.tensor([0, 1]) random_selection = torch.tensor([0, 1])
output_image, output_target = transform._copy_paste(image, target, paste_image, paste_target, random_selection) output_image, output_target = transform._copy_paste(
image, target, paste_image, paste_target, random_selection, blending, resize_interpolation, antialias
)
assert output_image.unique().tolist() == [2, 10] assert output_image.unique().tolist() == [2, 10]
assert output_target["boxes"].shape == (4, 4) assert output_target["boxes"].shape == (4, 4)
......
import math import math
import numbers import numbers
import warnings import warnings
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -15,6 +15,8 @@ from ._utils import has_any, query_chw ...@@ -15,6 +15,8 @@ from ._utils import has_any, query_chw
class RandomErasing(_RandomApplyTransform): class RandomErasing(_RandomApplyTransform):
_transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image)
def __init__( def __init__(
self, self,
p: float = 0.5, p: float = 0.5,
...@@ -86,7 +88,9 @@ class RandomErasing(_RandomApplyTransform): ...@@ -86,7 +88,9 @@ class RandomErasing(_RandomApplyTransform):
return dict(i=i, j=j, h=h, w=w, v=v) return dict(i=i, j=j, h=h, w=w, v=v)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(
self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any]
) -> Union[torch.Tensor, features.Image, PIL.Image.Image]:
if params["v"] is not None: if params["v"] is not None:
inpt = F.erase(inpt, **params) inpt = F.erase(inpt, **params)
...@@ -94,7 +98,7 @@ class RandomErasing(_RandomApplyTransform): ...@@ -94,7 +98,7 @@ class RandomErasing(_RandomApplyTransform):
class _BaseMixupCutmix(_RandomApplyTransform): class _BaseMixupCutmix(_RandomApplyTransform):
def __init__(self, *, alpha: float, p: float = 0.5) -> None: def __init__(self, alpha: float, p: float = 0.5) -> None:
super().__init__(p=p) super().__init__(p=p)
self.alpha = alpha self.alpha = alpha
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
...@@ -188,10 +192,12 @@ class SimpleCopyPaste(_RandomApplyTransform): ...@@ -188,10 +192,12 @@ class SimpleCopyPaste(_RandomApplyTransform):
p: float = 0.5, p: float = 0.5,
blending: bool = True, blending: bool = True,
resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR, resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> None: ) -> None:
super().__init__(p=p) super().__init__(p=p)
self.resize_interpolation = resize_interpolation self.resize_interpolation = resize_interpolation
self.blending = blending self.blending = blending
self.antialias = antialias
def _copy_paste( def _copy_paste(
self, self,
...@@ -200,8 +206,9 @@ class SimpleCopyPaste(_RandomApplyTransform): ...@@ -200,8 +206,9 @@ class SimpleCopyPaste(_RandomApplyTransform):
paste_image: Any, paste_image: Any,
paste_target: Dict[str, Any], paste_target: Dict[str, Any],
random_selection: torch.Tensor, random_selection: torch.Tensor,
blending: bool = True, blending: bool,
resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR, resize_interpolation: F.InterpolationMode,
antialias: Optional[bool],
) -> Tuple[Any, Dict[str, Any]]: ) -> Tuple[Any, Dict[str, Any]]:
paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection]) paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
...@@ -217,7 +224,7 @@ class SimpleCopyPaste(_RandomApplyTransform): ...@@ -217,7 +224,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
size1 = image.shape[-2:] size1 = image.shape[-2:]
size2 = paste_image.shape[-2:] size2 = paste_image.shape[-2:]
if size1 != size2: if size1 != size2:
paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation) paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias)
paste_masks = F.resize(paste_masks, size=size1) paste_masks = F.resize(paste_masks, size=size1)
paste_boxes = F.resize(paste_boxes, size=size1) paste_boxes = F.resize(paste_boxes, size=size1)
...@@ -356,6 +363,7 @@ class SimpleCopyPaste(_RandomApplyTransform): ...@@ -356,6 +363,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
random_selection=random_selection, random_selection=random_selection,
blending=self.blending, blending=self.blending,
resize_interpolation=self.resize_interpolation, resize_interpolation=self.resize_interpolation,
antialias=self.antialias,
) )
output_images.append(output_image) output_images.append(output_image)
output_targets.append(output_target) output_targets.append(output_target)
......
...@@ -116,8 +116,8 @@ class _AutoAugmentBase(Transform): ...@@ -116,8 +116,8 @@ class _AutoAugmentBase(Transform):
angle=0.0, angle=0.0,
translate=[int(magnitude), 0], translate=[int(magnitude), 0],
scale=1.0, scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation, interpolation=interpolation,
shear=[0.0, 0.0],
fill=fill_, fill=fill_,
) )
elif transform_id == "TranslateY": elif transform_id == "TranslateY":
...@@ -126,8 +126,8 @@ class _AutoAugmentBase(Transform): ...@@ -126,8 +126,8 @@ class _AutoAugmentBase(Transform):
angle=0.0, angle=0.0,
translate=[0, int(magnitude)], translate=[0, int(magnitude)],
scale=1.0, scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation, interpolation=interpolation,
shear=[0.0, 0.0],
fill=fill_, fill=fill_,
) )
elif transform_id == "Rotate": elif transform_id == "Rotate":
......
...@@ -112,7 +112,7 @@ class RandomPhotometricDistort(Transform): ...@@ -112,7 +112,7 @@ class RandomPhotometricDistort(Transform):
channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None, channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
) )
def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any: def _permute_channels(self, inpt: Any, permutation: torch.Tensor) -> Any:
if isinstance(inpt, PIL.Image.Image): if isinstance(inpt, PIL.Image.Image):
inpt = F.to_image_tensor(inpt) inpt = F.to_image_tensor(inpt)
...@@ -125,7 +125,9 @@ class RandomPhotometricDistort(Transform): ...@@ -125,7 +125,9 @@ class RandomPhotometricDistort(Transform):
return output return output
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(
self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any]
) -> Union[torch.Tensor, features.Image, PIL.Image.Image]:
if params["brightness"]: if params["brightness"]:
inpt = F.adjust_brightness( inpt = F.adjust_brightness(
inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1]) inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
......
...@@ -22,7 +22,7 @@ class Compose(Transform): ...@@ -22,7 +22,7 @@ class Compose(Transform):
class RandomApply(_RandomApplyTransform): class RandomApply(_RandomApplyTransform):
def __init__(self, transform: Transform, *, p: float = 0.5) -> None: def __init__(self, transform: Transform, p: float = 0.5) -> None:
super().__init__(p=p) super().__init__(p=p)
self.transform = transform self.transform = transform
......
import warnings import warnings
from typing import Any, Dict from typing import Any, Dict, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -14,6 +14,9 @@ from ._transform import _RandomApplyTransform ...@@ -14,6 +14,9 @@ from ._transform import _RandomApplyTransform
from ._utils import query_chw from ._utils import query_chw
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
class ToTensor(Transform): class ToTensor(Transform):
_transformed_types = (PIL.Image.Image, np.ndarray) _transformed_types = (PIL.Image.Image, np.ndarray)
...@@ -24,7 +27,7 @@ class ToTensor(Transform): ...@@ -24,7 +27,7 @@ class ToTensor(Transform):
) )
super().__init__() super().__init__()
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
return _F.to_tensor(inpt) return _F.to_tensor(inpt)
...@@ -52,8 +55,11 @@ class Grayscale(Transform): ...@@ -52,8 +55,11 @@ class Grayscale(Transform):
super().__init__() super().__init__()
self.num_output_channels = num_output_channels self.num_output_channels = num_output_channels
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: DType, params: Dict[str, Any]) -> DType:
return _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY)
return output
class RandomGrayscale(_RandomApplyTransform): class RandomGrayscale(_RandomApplyTransform):
...@@ -78,5 +84,8 @@ class RandomGrayscale(_RandomApplyTransform): ...@@ -78,5 +84,8 @@ class RandomGrayscale(_RandomApplyTransform):
num_input_channels, _, _ = query_chw(sample) num_input_channels, _, _ = query_chw(sample)
return dict(num_input_channels=num_input_channels) return dict(num_input_channels=num_input_channels)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: DType, params: Dict[str, Any]) -> DType:
return _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY)
return output
...@@ -15,6 +15,9 @@ from ._transform import _RandomApplyTransform ...@@ -15,6 +15,9 @@ from ._transform import _RandomApplyTransform
from ._utils import _check_sequence_input, _setup_angle, _setup_size, has_all, has_any, query_bounding_box, query_chw from ._utils import _check_sequence_input, _setup_angle, _setup_size, has_all, has_any, query_bounding_box, query_chw
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
class RandomHorizontalFlip(_RandomApplyTransform): class RandomHorizontalFlip(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.horizontal_flip(inpt) return F.horizontal_flip(inpt)
...@@ -163,7 +166,7 @@ class FiveCrop(Transform): ...@@ -163,7 +166,7 @@ class FiveCrop(Transform):
super().__init__() super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: DType, params: Dict[str, Any]) -> Tuple[DType, DType, DType, DType, DType]:
return F.five_crop(inpt, self.size) return F.five_crop(inpt, self.size)
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
...@@ -184,7 +187,7 @@ class TenCrop(Transform): ...@@ -184,7 +187,7 @@ class TenCrop(Transform):
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip self.vertical_flip = vertical_flip
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: DType, params: Dict[str, Any]) -> List[DType]:
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip) return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
...@@ -713,11 +716,13 @@ class ScaleJitter(Transform): ...@@ -713,11 +716,13 @@ class ScaleJitter(Transform):
target_size: Tuple[int, int], target_size: Tuple[int, int],
scale_range: Tuple[float, float] = (0.1, 2.0), scale_range: Tuple[float, float] = (0.1, 2.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
): ):
super().__init__() super().__init__()
self.target_size = target_size self.target_size = target_size
self.scale_range = scale_range self.scale_range = scale_range
self.interpolation = interpolation self.interpolation = interpolation
self.antialias = antialias
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
_, orig_height, orig_width = query_chw(sample) _, orig_height, orig_width = query_chw(sample)
...@@ -729,7 +734,7 @@ class ScaleJitter(Transform): ...@@ -729,7 +734,7 @@ class ScaleJitter(Transform):
return dict(size=(new_height, new_width)) return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation) return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
class RandomShortestSize(Transform): class RandomShortestSize(Transform):
...@@ -738,11 +743,13 @@ class RandomShortestSize(Transform): ...@@ -738,11 +743,13 @@ class RandomShortestSize(Transform):
min_size: Union[List[int], Tuple[int], int], min_size: Union[List[int], Tuple[int], int],
max_size: int, max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
): ):
super().__init__() super().__init__()
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size) self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size self.max_size = max_size
self.interpolation = interpolation self.interpolation = interpolation
self.antialias = antialias
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
_, orig_height, orig_width = query_chw(sample) _, orig_height, orig_width = query_chw(sample)
...@@ -756,7 +763,7 @@ class RandomShortestSize(Transform): ...@@ -756,7 +763,7 @@ class RandomShortestSize(Transform):
return dict(size=(new_height, new_width)) return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation) return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
class FixedSizeCrop(Transform): class FixedSizeCrop(Transform):
......
...@@ -16,7 +16,7 @@ class ConvertBoundingBoxFormat(Transform): ...@@ -16,7 +16,7 @@ class ConvertBoundingBoxFormat(Transform):
format = features.BoundingBoxFormat[format] format = features.BoundingBoxFormat[format]
self.format = format self.format = format
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"]) output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
return features.BoundingBox.new_like(inpt, output, format=params["format"]) return features.BoundingBox.new_like(inpt, output, format=params["format"])
...@@ -28,9 +28,11 @@ class ConvertImageDtype(Transform): ...@@ -28,9 +28,11 @@ class ConvertImageDtype(Transform):
super().__init__() super().__init__()
self.dtype = dtype self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(
self, inpt: Union[torch.Tensor, features.Image], params: Dict[str, Any]
) -> Union[torch.Tensor, features.Image]:
output = F.convert_image_dtype(inpt, dtype=self.dtype) output = F.convert_image_dtype(inpt, dtype=self.dtype)
return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype) return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype) # type: ignore[arg-type]
class ConvertColorSpace(Transform): class ConvertColorSpace(Transform):
...@@ -54,7 +56,9 @@ class ConvertColorSpace(Transform): ...@@ -54,7 +56,9 @@ class ConvertColorSpace(Transform):
self.copy = copy self.copy = copy
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, features._Feature], params: Dict[str, Any]
) -> Union[torch.Tensor, PIL.Image.Image, features._Feature]:
return F.convert_color_space( return F.convert_color_space(
inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
) )
...@@ -63,6 +67,6 @@ class ConvertColorSpace(Transform): ...@@ -63,6 +67,6 @@ class ConvertColorSpace(Transform):
class ClampBoundingBoxes(Transform): class ClampBoundingBoxes(Transform):
_transformed_types = (features.BoundingBox,) _transformed_types = (features.BoundingBox,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
output = F.clamp_bounding_box(inpt, format=inpt.format, image_size=inpt.image_size) output = F.clamp_bounding_box(inpt, format=inpt.format, image_size=inpt.image_size)
return features.BoundingBox.new_like(inpt, output) return features.BoundingBox.new_like(inpt, output)
...@@ -68,7 +68,7 @@ class LinearTransformation(Transform): ...@@ -68,7 +68,7 @@ class LinearTransformation(Transform):
return super().forward(*inputs) return super().forward(*inputs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor:
# Image instance after linear transformation is not Image anymore due to unknown data range # Image instance after linear transformation is not Image anymore due to unknown data range
# Thus we will return Tensor for input Image # Thus we will return Tensor for input Image
...@@ -100,7 +100,7 @@ class Normalize(Transform): ...@@ -100,7 +100,7 @@ class Normalize(Transform):
self.mean = list(mean) self.mean = list(mean)
self.std = list(std) self.std = list(std)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor:
return F.normalize(inpt, mean=self.mean, std=self.std) return F.normalize(inpt, mean=self.mean, std=self.std)
def forward(self, *inpts: Any) -> Any: def forward(self, *inpts: Any) -> Any:
......
...@@ -56,7 +56,7 @@ class Transform(nn.Module): ...@@ -56,7 +56,7 @@ class Transform(nn.Module):
class _RandomApplyTransform(Transform): class _RandomApplyTransform(Transform):
def __init__(self, *, p: float = 0.5) -> None: def __init__(self, p: float = 0.5) -> None:
if not (0.0 <= p <= 1.0): if not (0.0 <= p <= 1.0):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].") raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
......
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import torch
from torch.nn.functional import one_hot from torch.nn.functional import one_hot
from torchvision.prototype import features from torchvision.prototype import features
...@@ -11,7 +12,7 @@ from torchvision.prototype.transforms import functional as F, Transform ...@@ -11,7 +12,7 @@ from torchvision.prototype.transforms import functional as F, Transform
class DecodeImage(Transform): class DecodeImage(Transform):
_transformed_types = (features.EncodedImage,) _transformed_types = (features.EncodedImage,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image: def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image:
return F.decode_image_with_pil(inpt) return F.decode_image_with_pil(inpt)
...@@ -39,18 +40,22 @@ class LabelToOneHot(Transform): ...@@ -39,18 +40,22 @@ class LabelToOneHot(Transform):
class ToImageTensor(Transform): class ToImageTensor(Transform):
_transformed_types = (features.is_simple_tensor, PIL.Image.Image, np.ndarray) _transformed_types = (features.is_simple_tensor, PIL.Image.Image, np.ndarray)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image: def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> features.Image:
return F.to_image_tensor(inpt) return F.to_image_tensor(inpt)
class ToImagePIL(Transform): class ToImagePIL(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, np.ndarray) _transformed_types = (features.is_simple_tensor, features.Image, np.ndarray)
def __init__(self, *, mode: Optional[str] = None) -> None: def __init__(self, mode: Optional[str] = None) -> None:
super().__init__() super().__init__()
self.mode = mode self.mode = mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image: def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> PIL.Image.Image:
return F.to_image_pil(inpt, mode=self.mode) return F.to_image_pil(inpt, mode=self.mode)
......
# TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators
from torchvision.transforms import InterpolationMode # usort: skip from torchvision.transforms import InterpolationMode # usort: skip
from ._meta import ( from ._meta import (
clamp_bounding_box, clamp_bounding_box,
......
from typing import Any from typing import Union
import PIL.Image import PIL.Image
...@@ -19,7 +19,15 @@ def erase_image_pil( ...@@ -19,7 +19,15 @@ def erase_image_pil(
return to_pil_image(output, mode=img.mode) return to_pil_image(output, mode=img.mode)
def erase(inpt: Any, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False) -> Any: def erase(
inpt: Union[torch.Tensor, PIL.Image.Image, features.Image],
i: int,
j: int,
h: int,
w: int,
v: torch.Tensor,
inplace: bool = False,
) -> Union[torch.Tensor, PIL.Image.Image, features.Image]:
if isinstance(inpt, torch.Tensor): if isinstance(inpt, torch.Tensor):
output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
if isinstance(inpt, features.Image): if isinstance(inpt, features.Image):
......
import warnings import warnings
from typing import Any from typing import Any, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -21,7 +21,9 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima ...@@ -21,7 +21,9 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
return _F.to_grayscale(inpt, num_output_channels=num_output_channels) return _F.to_grayscale(inpt, num_output_channels=num_output_channels)
def rgb_to_grayscale(inpt: Any, num_output_channels: int = 1) -> Any: def rgb_to_grayscale(
inpt: Union[PIL.Image.Image, torch.Tensor], num_output_channels: int = 1
) -> Union[PIL.Image.Image, torch.Tensor]:
old_color_space = features.Image.guess_color_space(inpt) if features.is_simple_tensor(inpt) else None old_color_space = features.Image.guess_color_space(inpt) if features.is_simple_tensor(inpt) else None
call = ", num_output_channels=3" if num_output_channels == 3 else "" call = ", num_output_channels=3" if num_output_channels == 3 else ""
......
...@@ -1133,7 +1133,7 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo ...@@ -1133,7 +1133,7 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo
return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip] return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
def ten_crop(inpt: DType, size: List[int], *, vertical_flip: bool = False) -> List[DType]: def ten_crop(inpt: DType, size: List[int], vertical_flip: bool = False) -> List[DType]:
if isinstance(inpt, torch.Tensor): if isinstance(inpt, torch.Tensor):
output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
if isinstance(inpt, features.Image): if isinstance(inpt, features.Image):
......
...@@ -203,7 +203,11 @@ def convert_color_space_image_pil( ...@@ -203,7 +203,11 @@ def convert_color_space_image_pil(
def convert_color_space( def convert_color_space(
inpt: Any, *, color_space: ColorSpace, old_color_space: Optional[ColorSpace] = None, copy: bool = True inpt: Union[PIL.Image.Image, torch.Tensor, features._Feature],
*,
color_space: ColorSpace,
old_color_space: Optional[ColorSpace] = None,
copy: bool = True,
) -> Any: ) -> Any:
if isinstance(inpt, Image): if isinstance(inpt, Image):
return inpt.to_color_space(color_space, copy=copy) return inpt.to_color_space(color_space, copy=copy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment