Unverified commit 7251769f, authored by Philip Meier, committed by GitHub

Transforms without dispatcher (#5421)



* add prototype transforms that don't need dispatchers

* cleanup

* remove legacy_transform decorator

* remove legacy classes

* remove explicit param passing

* streamline extra_repr

* remove obsolete ._supports() method

* cleanup

* remove Query

* cleanup

* fix tests

* kernels -> functional

* move image size and num channels extraction to functional

* extend legacy function to extract image size and num channels

* implement dispatching for auto augment

* fix auto augment dispatch

* revert some naming changes

* remove ability to pass params to autoaugment

* fix legacy image size extraction

* align prototype.transforms.functional with transforms.functional

* cleanup

* fix image size and channels extraction

* fix affine and rotate

* revert image size to (width, height)

* Minor corrections
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent f15ba56f
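The central change of this PR is dropping the table-driven `Dispatcher` objects (`_DISPATCHER`, `_FAIL_TYPES`) in favor of explicit type checks inside each transform's `_transform` method. A minimal sketch of the new pattern, condensed from the `HorizontalFlip` diff below (names are taken from this PR; the surrounding imports are assumed):

from typing import Any, Dict

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F


class HorizontalFlip(Transform):
    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        # Dispatch happens inline on the input type instead of via a
        # registered kernel table.
        if isinstance(input, features.Image):
            output = F.horizontal_flip_image_tensor(input)
            return features.Image.new_like(input, output)
        elif isinstance(input, features.BoundingBox):
            output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
            return features.BoundingBox.new_like(input, output)
        elif isinstance(input, PIL.Image.Image):
            return F.horizontal_flip_image_pil(input)
        elif isinstance(input, torch.Tensor):
            return F.horizontal_flip_image_tensor(input)
        else:
            # Unsupported inputs pass through unchanged.
            return input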
import itertools
import PIL.Image
import pytest
import torch
from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels
from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
from torchvision.prototype import transforms, features
from torchvision.transforms.functional import to_pil_image
@@ -25,15 +24,6 @@ def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
yield bounding_box.data
INPUT_CREATIONS_FNS = {
features.Image: make_images,
features.BoundingBox: make_bounding_boxes,
features.OneHotLabel: make_one_hot_labels,
torch.Tensor: make_vanilla_tensor_images,
PIL.Image.Image: make_pil_images,
}
def parametrize(transforms_with_inputs):
return pytest.mark.parametrize(
("transform", "input"),
@@ -52,15 +42,21 @@ def parametrize(transforms_with_inputs):
def parametrize_from_transforms(*transforms):
transforms_with_inputs = []
for transform in transforms:
dispatcher = transform._DISPATCHER
if dispatcher is None:
continue
for type_ in dispatcher._kernels:
for creation_fn in [
make_images,
make_bounding_boxes,
make_one_hot_labels,
make_vanilla_tensor_images,
make_pil_images,
]:
inputs = list(creation_fn())
try:
inputs = INPUT_CREATIONS_FNS[type_]()
except KeyError:
output = transform(inputs[0])
except Exception:
continue
else:
if output is inputs[0]:
continue
transforms_with_inputs.append((transform, inputs))
@@ -69,7 +65,7 @@ def parametrize_from_transforms(*transforms):
class TestSmoke:
@parametrize_from_transforms(
transforms.RandomErasing(),
transforms.RandomErasing(p=1.0),
transforms.HorizontalFlip(),
transforms.Resize([16, 16]),
transforms.CenterCrop([16, 16]),
@@ -141,35 +137,6 @@ class TestSmoke:
def test_normalize(self, transform, input):
transform(input)
@parametrize(
[
(
transforms.ConvertColorSpace("grayscale"),
itertools.chain(
make_images(),
make_vanilla_tensor_images(color_spaces=["rgb"]),
make_pil_images(color_spaces=["rgb"]),
),
)
]
)
def test_convert_bounding_color_space(self, transform, input):
transform(input)
@parametrize(
[
(
transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"),
itertools.chain(
make_bounding_boxes(),
make_vanilla_tensor_bounding_boxes(formats=["xywh"]),
),
)
]
)
def test_convert_bounding_box_format(self, transform, input):
transform(input)
@parametrize(
[
(
@@ -3,7 +3,7 @@ import itertools
import pytest
import torch.testing
import torchvision.prototype.transforms.kernels as K
import torchvision.prototype.transforms.functional as F
from torch import jit
from torch.nn.functional import one_hot
from torchvision.prototype import features
@@ -134,10 +134,10 @@ class SampleInput:
self.kwargs = kwargs
class KernelInfo:
class FunctionalInfo:
def __init__(self, name, *, sample_inputs_fn):
self.name = name
self.kernel = getattr(K, name)
self.functional = getattr(F, name)
self._sample_inputs_fn = sample_inputs_fn
def sample_inputs(self):
@@ -146,21 +146,21 @@ class KernelInfo:
def __call__(self, *args, **kwargs):
if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput):
sample_input = args[0]
return self.kernel(*sample_input.args, **sample_input.kwargs)
return self.functional(*sample_input.args, **sample_input.kwargs)
return self.kernel(*args, **kwargs)
return self.functional(*args, **kwargs)
KERNEL_INFOS = []
FUNCTIONAL_INFOS = []
def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
FUNCTIONAL_INFOS.append(FunctionalInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
return sample_inputs_fn
@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_image():
def horizontal_flip_image_tensor():
for image in make_images():
yield SampleInput(image)
@@ -172,12 +172,12 @@ def horizontal_flip_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def resize_image():
def resize_image_tensor():
for image, interpolation in itertools.product(
make_images(),
[
K.InterpolationMode.BILINEAR,
K.InterpolationMode.NEAREST,
F.InterpolationMode.BILINEAR,
F.InterpolationMode.NEAREST,
],
):
height, width = image.shape[-2:]
@@ -200,20 +200,20 @@ def resize_bounding_box():
class TestKernelsCommon:
@pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name)
def test_scriptable(self, kernel_info):
jit.script(kernel_info.kernel)
@pytest.mark.parametrize("functional_info", FUNCTIONAL_INFOS, ids=lambda functional_info: functional_info.name)
def test_scriptable(self, functional_info):
jit.script(functional_info.functional)
@pytest.mark.parametrize(
("kernel_info", "sample_input"),
("functional_info", "sample_input"),
[
pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}")
for kernel_info in KERNEL_INFOS
for idx, sample_input in enumerate(kernel_info.sample_inputs())
pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
for functional_info in FUNCTIONAL_INFOS
for idx, sample_input in enumerate(functional_info.sample_inputs())
],
)
def test_eager_vs_scripted(self, kernel_info, sample_input):
eager = kernel_info(sample_input)
scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs)
def test_eager_vs_scripted(self, functional_info, sample_input):
eager = functional_info(sample_input)
scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
torch.testing.assert_close(eager, scripted)
@@ -41,7 +41,7 @@ class BoundingBox(_Feature):
# promote this out of the prototype state
# import at runtime to avoid cyclic imports
from torchvision.prototype.transforms.kernels import convert_bounding_box_format
from torchvision.prototype.transforms.functional import convert_bounding_box_format
if isinstance(format, str):
format = BoundingBoxFormat[format]
@@ -43,7 +43,7 @@ class EncodedImage(EncodedData):
# promote this out of the prototype state
# import at runtime to avoid cyclic imports
from torchvision.prototype.transforms.kernels import decode_image_with_pil
from torchvision.prototype.transforms.functional import decode_image_with_pil
return Image(decode_image_with_pil(self))
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip
from . import kernels # usort: skip
from torchvision.transforms import InterpolationMode, AutoAugmentPolicy # usort: skip
from . import functional # usort: skip
from ._transform import Transform # usort: skip
from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace
from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
from ._type_conversion import DecodeImage, LabelToOneHot
@@ -3,7 +3,6 @@ import numbers
import warnings
from typing import Any, Dict, Tuple
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
@@ -12,9 +11,6 @@ from ._utils import query_image
class RandomErasing(Transform):
_DISPATCHER = F.erase
_FAIL_TYPES = {PIL.Image.Image, features.BoundingBox, features.SegmentationMask}
def __init__(
self,
p: float = 0.5,
@@ -45,8 +41,8 @@
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
img_h, img_w = F.get_image_size(image)
img_c = F.get_image_num_channels(image)
img_w, img_h = F.get_image_size(image)
if isinstance(self.value, (int, float)):
value = [self.value]
@@ -93,16 +89,24 @@
return dict(zip("ijhwv", (i, j, h, w, v)))
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if torch.rand(1) >= self.p:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
output = F.erase_image_tensor(input, **params)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
return F.erase_image_tensor(input, **params)
else:
return input
return super()._transform(input, params)
def forward(self, *inputs: Any) -> Any:
if torch.rand(1) >= self.p:
return inputs if len(inputs) > 1 else inputs[0]
return super().forward(*inputs)
class RandomMixup(Transform):
_DISPATCHER = F.mixup
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
def __init__(self, *, alpha: float) -> None:
super().__init__()
self.alpha = alpha
@@ -111,11 +115,20 @@ class RandomMixup(Transform):
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(())))
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
output = F.mixup_image_tensor(input, **params)
return features.Image.new_like(input, output)
elif isinstance(input, features.OneHotLabel):
output = F.mixup_one_hot_label(input, **params)
return features.OneHotLabel.new_like(input, output)
else:
return input
class RandomCutmix(Transform):
_DISPATCHER = F.cutmix
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
class RandomCutmix(Transform):
def __init__(self, *, alpha: float) -> None:
super().__init__()
self.alpha = alpha
@@ -125,7 +138,7 @@ class RandomCutmix(Transform):
lam = float(self._dist.sample(()))
image = query_image(sample)
H, W = F.get_image_size(image)
W, H = F.get_image_size(image)
r_x = torch.randint(W, ())
r_y = torch.randint(H, ())
@@ -143,3 +156,15 @@
lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
return dict(box=box, lam_adjusted=lam_adjusted)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
output = F.cutmix_image_tensor(input, box=params["box"])
return features.Image.new_like(input, output)
elif isinstance(input, features.OneHotLabel):
output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"])
return features.OneHotLabel.new_like(input, output)
else:
return input
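Both transforms operate on batched samples: the image kernels require at least four dimensions and the one-hot-label kernels at least two (see the kernels further down in this diff). A hedged usage sketch, with the shapes and the dict sample layout assumed:

import torch
from torchvision.prototype import features, transforms

sample = dict(
    image=features.Image(torch.rand(4, 3, 16, 16)),  # (N, C, H, W) batch
    label=features.OneHotLabel(torch.eye(4)),        # (N, num_classes)
)
# Blends every sample with its roll-by-one neighbor in the batch.
mixed = transforms.RandomMixup(alpha=1.0)(sample)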
@@ -21,65 +21,31 @@ class _AutoAugmentBase(Transform):
self.interpolation = interpolation
self.fill = fill
_DISPATCHER_MAP: Dict[str, Callable[[Any, float, InterpolationMode, Optional[List[float]]], Any]] = {
"Identity": lambda input, magnitude, interpolation, fill: input,
"ShearX": lambda input, magnitude, interpolation, fill: F.affine(
input,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(magnitude), 0.0],
interpolation=interpolation,
fill=fill,
),
"ShearY": lambda input, magnitude, interpolation, fill: F.affine(
input,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(magnitude)],
interpolation=interpolation,
fill=fill,
),
"TranslateX": lambda input, magnitude, interpolation, fill: F.affine(
input,
angle=0.0,
translate=[int(magnitude), 0],
scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation,
fill=fill,
),
"TranslateY": lambda input, magnitude, interpolation, fill: F.affine(
input,
angle=0.0,
translate=[0, int(magnitude)],
scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation,
fill=fill,
),
"Rotate": lambda input, magnitude, interpolation, fill: F.rotate(input, angle=magnitude),
"Brightness": lambda input, magnitude, interpolation, fill: F.adjust_brightness(
input, brightness_factor=1.0 + magnitude
),
"Color": lambda input, magnitude, interpolation, fill: F.adjust_saturation(
input, saturation_factor=1.0 + magnitude
),
"Contrast": lambda input, magnitude, interpolation, fill: F.adjust_contrast(
input, contrast_factor=1.0 + magnitude
),
"Sharpness": lambda input, magnitude, interpolation, fill: F.adjust_sharpness(
input, sharpness_factor=1.0 + magnitude
),
"Posterize": lambda input, magnitude, interpolation, fill: F.posterize(input, bits=int(magnitude)),
"Solarize": lambda input, magnitude, interpolation, fill: F.solarize(input, threshold=magnitude),
"AutoContrast": lambda input, magnitude, interpolation, fill: F.autocontrast(input),
"Equalize": lambda input, magnitude, interpolation, fill: F.equalize(input),
"Invert": lambda input, magnitude, interpolation, fill: F.invert(input),
}
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _apply_transform(self, sample: Any, transform_id: str, magnitude: float) -> Any:
def dispatch(
image_tensor_kernel: Callable,
image_pil_kernel: Callable,
input: Any,
*args: Any,
**kwargs: Any,
) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
output = image_tensor_kernel(input, *args, **kwargs)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
return image_tensor_kernel(input, *args, **kwargs)
elif isinstance(input, PIL.Image.Image):
return image_pil_kernel(input, *args, **kwargs)
else:
return input
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
num_channels = F.get_image_num_channels(image)
@@ -89,24 +55,104 @@
elif fill is not None:
fill = [float(f) for f in fill]
return dict(interpolation=self.interpolation, fill=fill)
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _apply_transform(self, sample: Any, params: Dict[str, Any], transform_id: str, magnitude: float) -> Any:
dispatcher = self._DISPATCHER_MAP[transform_id]
interpolation = self.interpolation
def transform(input: Any) -> Any:
if type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image):
return dispatcher(input, magnitude, params["interpolation"], params["fill"])
elif type(input) in {features.BoundingBox, features.SegmentationMask}:
if type(input) in {features.BoundingBox, features.SegmentationMask}:
raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
else:
elif not (type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image)):
return input
if transform_id == "Identity":
return input
elif transform_id == "ShearX":
return dispatch(
F.affine_image_tensor,
F.affine_image_pil,
input,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(magnitude), 0.0],
interpolation=interpolation,
fill=fill,
)
elif transform_id == "ShearY":
return dispatch(
F.affine_image_tensor,
F.affine_image_pil,
input,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(magnitude)],
interpolation=interpolation,
fill=fill,
)
elif transform_id == "TranslateX":
return dispatch(
F.affine_image_tensor,
F.affine_image_pil,
input,
angle=0.0,
translate=[int(magnitude), 0],
scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation,
fill=fill,
)
elif transform_id == "TranslateY":
return dispatch(
F.affine_image_tensor,
F.affine_image_pil,
input,
angle=0.0,
translate=[0, int(magnitude)],
scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation,
fill=fill,
)
elif transform_id == "Rotate":
return dispatch(F.rotate_image_tensor, F.rotate_image_pil, input, angle=magnitude)
elif transform_id == "Brightness":
return dispatch(
F.adjust_brightness_image_tensor,
F.adjust_brightness_image_pil,
input,
brightness_factor=1.0 + magnitude,
)
elif transform_id == "Color":
return dispatch(
F.adjust_saturation_image_tensor,
F.adjust_saturation_image_pil,
input,
saturation_factor=1.0 + magnitude,
)
elif transform_id == "Contrast":
return dispatch(
F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, input, contrast_factor=1.0 + magnitude
)
elif transform_id == "Sharpness":
return dispatch(
F.adjust_sharpness_image_tensor,
F.adjust_sharpness_image_pil,
input,
sharpness_factor=1.0 + magnitude,
)
elif transform_id == "Posterize":
return dispatch(F.posterize_image_tensor, F.posterize_image_pil, input, bits=int(magnitude))
elif transform_id == "Solarize":
return dispatch(F.solarize_image_tensor, F.solarize_image_pil, input, threshold=magnitude)
elif transform_id == "AutoContrast":
return dispatch(F.autocontrast_image_tensor, F.autocontrast_image_pil, input)
elif transform_id == "Equalize":
return dispatch(F.equalize_image_tensor, F.equalize_image_pil, input)
elif transform_id == "Invert":
return dispatch(F.invert_image_tensor, F.invert_image_pil, input)
else:
raise ValueError(f"No transform available for {transform_id}")
return apply_recursively(transform, sample)
@@ -114,7 +160,7 @@ class AutoAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
@@ -228,9 +274,8 @@
else:
raise ValueError(f"The provided policy {policy} is not recognized.")
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
params = params or self._get_params(sample)
image = query_image(sample)
image_size = F.get_image_size(image)
@@ -251,7 +296,7 @@
else:
magnitude = 0.0
sample = self._apply_transform(sample, params, transform_id, magnitude)
sample = self._apply_transform(sample, transform_id, magnitude)
return sample
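With the `params` keyword removed from `forward`, the auto-augment transforms are driven entirely by their constructor arguments. A hedged usage sketch (the `policy` argument and the uint8 input are assumptions based on the classic `transforms.AutoAugment` API):

import torch
from torchvision.prototype import transforms

aa = transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.IMAGENET)
img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
out = aa(img)  # plain tensors are routed to the *_image_tensor kernels via dispatch()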
@@ -261,7 +306,7 @@ class RandAugment(_AutoAugmentBase):
"Identity": (lambda num_bins, image_size: None, False),
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
@@ -285,9 +330,8 @@
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
params = params or self._get_params(sample)
image = query_image(sample)
image_size = F.get_image_size(image)
@@ -303,7 +347,7 @@
else:
magnitude = 0.0
sample = self._apply_transform(sample, params, transform_id, magnitude)
sample = self._apply_transform(sample, transform_id, magnitude)
return sample
@@ -335,9 +379,8 @@ class TrivialAugmentWide(_AutoAugmentBase):
super().__init__(**kwargs)
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
params = params or self._get_params(sample)
image = query_image(sample)
image_size = F.get_image_size(image)
@@ -352,4 +395,4 @@
else:
magnitude = 0.0
return self._apply_transform(sample, params, transform_id, magnitude)
return self._apply_transform(sample, transform_id, magnitude)
from typing import Any, Optional, Dict
from typing import Any
import torch
@@ -6,13 +6,13 @@ from ._transform import Transform
class Compose(Transform):
def __init__(self, *transforms: Transform):
def __init__(self, *transforms: Transform) -> None:
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: # type: ignore[override]
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
for transform in self.transforms:
sample = transform(sample)
@@ -25,38 +25,38 @@ class RandomApply(Transform):
self.transform = transform
self.p = p
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if float(torch.rand(())) < self.p:
return sample
return self.transform(sample, params=params)
return self.transform(sample)
def extra_repr(self) -> str:
return f"p={self.p}"
class RandomChoice(Transform):
def __init__(self, *transforms: Transform):
def __init__(self, *transforms: Transform) -> None:
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: # type: ignore[override]
def forward(self, *inputs: Any) -> Any:
idx = int(torch.randint(len(self.transforms), size=()))
transform = self.transforms[idx]
return transform(*inputs)
class RandomOrder(Transform):
def __init__(self, *transforms: Transform):
def __init__(self, *transforms: Transform) -> None:
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: # type: ignore[override]
def forward(self, *inputs: Any) -> Any:
for idx in torch.randperm(len(self.transforms)):
transform = self.transforms[idx]
inputs = transform(*inputs)
@@ -2,6 +2,7 @@ import math
import warnings
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
@@ -11,41 +12,69 @@ from ._utils import query_image
class HorizontalFlip(Transform):
_DISPATCHER = F.horizontal_flip
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.horizontal_flip_image_tensor(input)
return features.Image.new_like(input, output)
elif isinstance(input, features.BoundingBox):
output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
return features.BoundingBox.new_like(input, output)
elif isinstance(input, PIL.Image.Image):
return F.horizontal_flip_image_pil(input)
elif isinstance(input, torch.Tensor):
return F.horizontal_flip_image_tensor(input)
else:
return input
class Resize(Transform):
_DISPATCHER = F.resize
def __init__(
self,
size: Union[int, Sequence[int]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self.size = size
self.size = [size] if isinstance(size, int) else list(size)
self.interpolation = interpolation
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(size=self.size, interpolation=self.interpolation)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.resize_image_tensor(input, self.size, interpolation=self.interpolation)
return features.Image.new_like(input, output)
elif isinstance(input, features.SegmentationMask):
output = F.resize_segmentation_mask(input, self.size)
return features.SegmentationMask.new_like(input, output)
elif isinstance(input, features.BoundingBox):
output = F.resize_bounding_box(input, self.size, image_size=input.image_size)
return features.BoundingBox.new_like(input, output, image_size=self.size)
elif isinstance(input, PIL.Image.Image):
return F.resize_image_pil(input, self.size, interpolation=self.interpolation)
elif isinstance(input, torch.Tensor):
return F.resize_image_tensor(input, self.size, interpolation=self.interpolation)
else:
return input
class CenterCrop(Transform):
_DISPATCHER = F.center_crop
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
def __init__(self, output_size: List[int]):
super().__init__()
self.output_size = output_size
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(output_size=self.output_size)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
output = F.center_crop_image_tensor(input, self.output_size)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
return F.center_crop_image_tensor(input, self.output_size)
elif isinstance(input, PIL.Image.Image):
return F.center_crop_image_pil(input, self.output_size)
else:
return input
class RandomResizedCrop(Transform):
_DISPATCHER = F.resized_crop
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
def __init__(
self,
size: Union[int, Sequence[int]],
@@ -80,7 +109,7 @@ class RandomResizedCrop(Transform):
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
height, width = F.get_image_size(image)
width, height = F.get_image_size(image)
area = height * width
log_ratio = torch.log(torch.tensor(self.ratio))
@@ -115,4 +144,19 @@
i = (height - h) // 2
j = (width - w) // 2
return dict(top=i, left=j, height=h, width=w, size=self.size)
return dict(top=i, left=j, height=h, width=w)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
output = F.resized_crop_image_tensor(
input, **params, size=list(self.size), interpolation=self.interpolation
)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation)
elif isinstance(input, PIL.Image.Image):
return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation)
else:
return input
from typing import Union, Any, Dict, Optional
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
@@ -7,24 +8,18 @@ from torchvision.transforms.functional import convert_image_dtype
class ConvertBoundingBoxFormat(Transform):
_DISPATCHER = F.convert_format
def __init__(
self,
format: Union[str, features.BoundingBoxFormat],
old_format: Optional[Union[str, features.BoundingBoxFormat]] = None,
) -> None:
def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
super().__init__()
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
self.format = format
if isinstance(old_format, str):
old_format = features.BoundingBoxFormat[old_format]
self.old_format = old_format
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(format=self.format, old_format=self.old_format)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if type(input) is features.BoundingBox:
output = F.convert_bounding_box_format(input, old_format=input.format, new_format=params["format"])
return features.BoundingBox.new_like(input, output, format=params["format"])
else:
return input
class ConvertImageDtype(Transform):
@@ -33,21 +28,50 @@ class ConvertImageDtype(Transform):
self.dtype = dtype
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, features.Image):
if type(input) is features.Image:
output = convert_image_dtype(input, dtype=self.dtype)
return features.Image.new_like(input, output, dtype=self.dtype)
else:
return input
output = convert_image_dtype(input, dtype=self.dtype)
return features.Image.new_like(input, output, dtype=self.dtype)
class ConvertColorSpace(Transform):
_DISPATCHER = F.convert_color_space
def __init__(self, color_space: Union[str, features.ColorSpace]) -> None:
class ConvertImageColorSpace(Transform):
def __init__(
self,
color_space: Union[str, features.ColorSpace],
old_color_space: Optional[Union[str, features.ColorSpace]] = None,
) -> None:
super().__init__()
if isinstance(color_space, str):
color_space = features.ColorSpace[color_space]
self.color_space = color_space
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(color_space=self.color_space)
if isinstance(old_color_space, str):
old_color_space = features.ColorSpace[old_color_space]
self.old_color_space = old_color_space
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.convert_image_color_space_tensor(
input, old_color_space=input.color_space, new_color_space=self.color_space
)
return features.Image.new_like(input, output, color_space=self.color_space)
elif isinstance(input, torch.Tensor):
if self.old_color_space is None:
raise RuntimeError("For vanilla tensors the `old_color_space` needs to be provided.")
return F.convert_image_color_space_tensor(
input, old_color_space=self.old_color_space, new_color_space=self.color_space
)
elif isinstance(input, PIL.Image.Image):
old_color_space = {
"L": features.ColorSpace.GRAYSCALE,
"RGB": features.ColorSpace.RGB,
}.get(input.mode, features.ColorSpace.OTHER)
return F.convert_image_color_space_pil(
input, old_color_space=old_color_space, new_color_space=self.color_space
)
else:
return input
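The `old_color_space` argument only matters for vanilla tensors; a `features.Image` carries its color space as metadata, and PIL images infer it from `.mode`. A hedged usage sketch (tensor shape assumed):

import torch
from torchvision.prototype import features, transforms

convert = transforms.ConvertImageColorSpace(
    color_space=features.ColorSpace.GRAYSCALE,
    old_color_space=features.ColorSpace.RGB,  # required for vanilla tensors
)
gray = convert(torch.rand(3, 16, 16))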
@@ -17,11 +17,11 @@ class Lambda(Transform):
self.types = types
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, self.types):
if type(input) in self.types:
return self.fn(input)
else:
return input
return self.fn(input)
def extra_repr(self) -> str:
extras = []
name = getattr(self.fn, "__name__", None)
@@ -32,15 +32,18 @@
class Normalize(Transform):
_DISPATCHER = F.normalize
def __init__(self, mean: List[float], std: List[float]):
super().__init__()
self.mean = mean
self.std = std
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(mean=self.mean, std=self.std)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, torch.Tensor):
# We don't need to differentiate between vanilla tensors and features.Image's here, since the result of the
# normalization transform is no longer a features.Image
return F.normalize_image_tensor(input, mean=self.mean, std=self.std)
else:
return input
class ToDtype(Lambda):
import enum
import functools
from typing import Any, Dict, Optional, Set, Type
from typing import Any, Dict
from torch import nn
from torchvision.prototype.utils._internal import apply_recursively
from torchvision.utils import _log_api_usage_once
from .functional._utils import Dispatcher
class Transform(nn.Module):
_DISPATCHER: Optional[Dispatcher] = None
_FAIL_TYPES: Set[Type] = set()
def __init__(self) -> None:
super().__init__()
_log_api_usage_once(self)
@@ -21,19 +16,11 @@
return dict()
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not self._DISPATCHER:
raise NotImplementedError()
if input in self._DISPATCHER:
return self._DISPATCHER(input, **params)
elif type(input) in self._FAIL_TYPES:
raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
else:
return input
raise NotImplementedError
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
return apply_recursively(functools.partial(self._transform, params=params or self._get_params(sample)), sample)
return apply_recursively(functools.partial(self._transform, params=self._get_params(sample)), sample)
def extra_repr(self) -> str:
extra = []
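Under the slimmed-down base class, a subclass only overrides `_get_params` and `_transform`, and anything it does not recognize should be returned unchanged. A hypothetical minimal subclass (the name and behavior are invented for illustration, not part of this PR):

from typing import Any, Dict

import torch
from torchvision.prototype.transforms import Transform


class AddUniformNoise(Transform):  # hypothetical example transform
    def _get_params(self, sample: Any) -> Dict[str, Any]:
        # Sampled once per call, so every leaf of the sample sees the same value.
        return dict(noise_level=float(torch.rand(())))

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if isinstance(input, torch.Tensor) and input.is_floating_point():
            return input + params["noise_level"] * torch.rand_like(input)
        return input  # pass unsupported inputs through unchanged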
from typing import Any, Dict
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, kernels as K
from torchvision.prototype.transforms import Transform, functional as F
class DecodeImage(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, features.EncodedImage):
if type(input) is features.EncodedImage:
output = F.decode_image_with_pil(input)
return features.Image(output)
else:
return input
return features.Image(K.decode_image_with_pil(input))
class LabelToOneHot(Transform):
def __init__(self, num_categories: int = -1):
@@ -18,16 +19,15 @@ class LabelToOneHot(Transform):
self.num_categories = num_categories
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, features.Label):
if type(input) is features.Label:
num_categories = self.num_categories
if num_categories == -1 and input.categories is not None:
num_categories = len(input.categories)
output = F.label_to_one_hot(input, num_categories=num_categories)
return features.OneHotLabel(output, categories=input.categories)
else:
return input
num_categories = self.num_categories
if num_categories == -1 and input.categories is not None:
num_categories = len(input.categories)
return features.OneHotLabel(
K.label_to_one_hot(input, num_categories=num_categories), categories=input.categories
)
def extra_repr(self) -> str:
if self.num_categories == -1:
return ""
from typing import Any, Union, Optional
from typing import Any, Optional, Union
import PIL.Image
import torch
from ._augment import erase, mixup, cutmix
from torchvision.transforms import InterpolationMode # usort: skip
from ._utils import get_image_size, get_image_num_channels # usort: skip
from ._meta_conversion import (
convert_bounding_box_format,
convert_image_color_space_tensor,
convert_image_color_space_pil,
) # usort: skip
from ._augment import (
erase_image_tensor,
mixup_image_tensor,
mixup_one_hot_label,
cutmix_image_tensor,
cutmix_one_hot_label,
)
from ._color import (
adjust_brightness,
adjust_contrast,
adjust_saturation,
adjust_sharpness,
posterize,
solarize,
autocontrast,
equalize,
invert,
adjust_brightness_image_tensor,
adjust_brightness_image_pil,
adjust_contrast_image_tensor,
adjust_contrast_image_pil,
adjust_saturation_image_tensor,
adjust_saturation_image_pil,
adjust_sharpness_image_tensor,
adjust_sharpness_image_pil,
posterize_image_tensor,
posterize_image_pil,
solarize_image_tensor,
solarize_image_pil,
autocontrast_image_tensor,
autocontrast_image_pil,
equalize_image_tensor,
equalize_image_pil,
invert_image_tensor,
invert_image_pil,
adjust_hue_image_tensor,
adjust_hue_image_pil,
adjust_gamma_image_tensor,
adjust_gamma_image_pil,
)
from ._geometry import (
horizontal_flip_bounding_box,
horizontal_flip_image_tensor,
horizontal_flip_image_pil,
resize_bounding_box,
resize_image_tensor,
resize_image_pil,
resize_segmentation_mask,
center_crop_image_tensor,
center_crop_image_pil,
resized_crop_image_tensor,
resized_crop_image_pil,
affine_image_tensor,
affine_image_pil,
rotate_image_tensor,
rotate_image_pil,
pad_image_tensor,
pad_image_pil,
crop_image_tensor,
crop_image_pil,
perspective_image_tensor,
perspective_image_pil,
vertical_flip_image_tensor,
vertical_flip_image_pil,
)
from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate
from ._meta_conversion import convert_color_space, convert_format
from ._misc import normalize, get_image_size, get_image_num_channels
from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
from typing import Any
from typing import Tuple
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F
from ._utils import dispatch
@dispatch(
{
torch.Tensor: _F.erase,
features.Image: K.erase_image,
}
)
def erase(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
features.Image: K.mixup_image,
features.OneHotLabel: K.mixup_one_hot_label,
}
)
def mixup(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
features.Image: None,
features.OneHotLabel: None,
}
)
def cutmix(input: Any, *args: Any, **kwargs: Any) -> Any:
"""Perform the CutMix operation as introduced in the paper
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" <https://arxiv.org/abs/1905.04899>`_.
Dispatch to the corresponding kernels happens according to this table:
.. table::
:widths: 30 70
==================================================== ================================================================
:class:`~torchvision.prototype.features.Image` :func:`~torch.prototype.transforms.kernels.cutmix_image`
:class:`~torchvision.prototype.features.OneHotLabel` :func:`~torch.prototype.transforms.kernels.cutmix_one_hot_label`
==================================================== ================================================================
Please refer to the kernel documentations for a detailed explanation of the functionality and parameters.
"""
if isinstance(input, features.Image):
kwargs.pop("lam_adjusted", None)
output = K.cutmix_image(input, **kwargs)
return features.Image.new_like(input, output)
elif isinstance(input, features.OneHotLabel):
kwargs.pop("box", None)
output = K.cutmix_one_hot_label(input, **kwargs)
return features.OneHotLabel.new_like(input, output)
raise RuntimeError
from torchvision.transforms import functional_tensor as _FT
erase_image_tensor = _FT.erase
def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor:
input = input.clone()
return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam))
def mixup_image_tensor(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor:
if image_batch.ndim < 4:
raise ValueError("Need a batch of images")
return _mixup_tensor(image_batch, -4, lam)
def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor:
if one_hot_label_batch.ndim < 2:
raise ValueError("Need a batch of one hot labels")
return _mixup_tensor(one_hot_label_batch, -2, lam)
def cutmix_image_tensor(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor:
if image_batch.ndim < 4:
raise ValueError("Need a batch of images")
x1, y1, x2, y2 = box
image_rolled = image_batch.roll(1, -4)
image_batch = image_batch.clone()
image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
return image_batch
def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor:
if one_hot_label_batch.ndim < 2:
raise ValueError("Need a batch of one hot labels")
return _mixup_tensor(one_hot_label_batch, -2, lam_adjusted)
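For reference, `_mixup_tensor` blends each element with its roll-by-one neighbor along the batch dimension: the result is `lam * x[i] + (1 - lam) * x[i - 1]`. A hedged usage sketch of the kernels above (shapes assumed):

import torch

images = torch.rand(4, 3, 8, 8)  # (N, C, H, W)
labels = torch.eye(4)            # (N, num_classes) one-hot batch

mixed_images = mixup_image_tensor(images, lam=0.7)  # 0.7 * x[i] + 0.3 * x[i - 1]
mixed_labels = mixup_one_hot_label(labels, lam=0.7)

cut_images = cutmix_image_tensor(images, box=(1, 1, 5, 5))    # paste the rolled batch into the box
cut_labels = cutmix_one_hot_label(labels, lam_adjusted=0.75)  # lam adjusted for the box area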
from typing import Any
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F
adjust_brightness_image_tensor = _FT.adjust_brightness
adjust_brightness_image_pil = _FP.adjust_brightness
from ._utils import dispatch
adjust_saturation_image_tensor = _FT.adjust_saturation
adjust_saturation_image_pil = _FP.adjust_saturation
adjust_contrast_image_tensor = _FT.adjust_contrast
adjust_contrast_image_pil = _FP.adjust_contrast
@dispatch(
{
torch.Tensor: _F.adjust_brightness,
PIL.Image.Image: _F.adjust_brightness,
features.Image: K.adjust_brightness_image,
}
)
def adjust_brightness(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
adjust_sharpness_image_tensor = _FT.adjust_sharpness
adjust_sharpness_image_pil = _FP.adjust_sharpness
posterize_image_tensor = _FT.posterize
posterize_image_pil = _FP.posterize
@dispatch(
{
torch.Tensor: _F.adjust_saturation,
PIL.Image.Image: _F.adjust_saturation,
features.Image: K.adjust_saturation_image,
}
)
def adjust_saturation(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
solarize_image_tensor = _FT.solarize
solarize_image_pil = _FP.solarize
autocontrast_image_tensor = _FT.autocontrast
autocontrast_image_pil = _FP.autocontrast
@dispatch(
{
torch.Tensor: _F.adjust_contrast,
PIL.Image.Image: _F.adjust_contrast,
features.Image: K.adjust_contrast_image,
}
)
def adjust_contrast(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
equalize_image_tensor = _FT.equalize
equalize_image_pil = _FP.equalize
invert_image_tensor = _FT.invert
invert_image_pil = _FP.invert
@dispatch(
{
torch.Tensor: _F.adjust_sharpness,
PIL.Image.Image: _F.adjust_sharpness,
features.Image: K.adjust_sharpness_image,
}
)
def adjust_sharpness(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
adjust_hue_image_tensor = _FT.adjust_hue
adjust_hue_image_pil = _FP.adjust_hue
@dispatch(
{
torch.Tensor: _F.posterize,
PIL.Image.Image: _F.posterize,
features.Image: K.posterize_image,
}
)
def posterize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.solarize,
PIL.Image.Image: _F.solarize,
features.Image: K.solarize_image,
}
)
def solarize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.autocontrast,
PIL.Image.Image: _F.autocontrast,
features.Image: K.autocontrast_image,
}
)
def autocontrast(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.equalize,
PIL.Image.Image: _F.equalize,
features.Image: K.equalize_image,
}
)
def equalize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.invert,
PIL.Image.Image: _F.invert,
features.Image: K.invert_image,
}
)
def invert(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.adjust_hue,
PIL.Image.Image: _F.adjust_hue,
features.Image: K.adjust_hue_image,
}
)
def adjust_hue(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.adjust_gamma,
PIL.Image.Image: _F.adjust_gamma,
features.Image: K.adjust_gamma_image,
}
)
def adjust_gamma(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
adjust_gamma_image_tensor = _FT.adjust_gamma
adjust_gamma_image_pil = _FP.adjust_gamma
from typing import Any
import numbers
from typing import Tuple, List, Optional, Sequence, Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F
from ._utils import dispatch
@dispatch(
{
torch.Tensor: _F.hflip,
PIL.Image.Image: _F.hflip,
features.Image: K.horizontal_flip_image,
features.BoundingBox: None,
},
)
def horizontal_flip(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
if isinstance(input, features.BoundingBox):
output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
return features.BoundingBox.new_like(input, output)
raise RuntimeError
@dispatch(
{
torch.Tensor: _F.resize,
PIL.Image.Image: _F.resize,
features.Image: K.resize_image,
features.SegmentationMask: K.resize_segmentation_mask,
features.BoundingBox: None,
}
)
def resize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
if isinstance(input, features.BoundingBox):
size = kwargs.pop("size")
output = K.resize_bounding_box(input, size=size, image_size=input.image_size)
return features.BoundingBox.new_like(input, output, image_size=size)
raise RuntimeError
@dispatch(
{
torch.Tensor: _F.center_crop,
PIL.Image.Image: _F.center_crop,
features.Image: K.center_crop_image,
}
)
def center_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.resized_crop,
PIL.Image.Image: _F.resized_crop,
features.Image: K.resized_crop_image,
}
)
def resized_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.affine,
PIL.Image.Image: _F.affine,
features.Image: K.affine_image,
}
)
def affine(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.rotate,
PIL.Image.Image: _F.rotate,
features.Image: K.rotate_image,
}
)
def rotate(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.pad,
PIL.Image.Image: _F.pad,
features.Image: K.pad_image,
}
)
def pad(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.crop,
PIL.Image.Image: _F.crop,
features.Image: K.crop_image,
}
)
def crop(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.perspective,
PIL.Image.Image: _F.perspective,
features.Image: K.perspective_image,
}
)
def perspective(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.vflip,
PIL.Image.Image: _F.vflip,
features.Image: K.vertical_flip_image,
}
)
def vertical_flip(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.five_crop,
PIL.Image.Image: _F.five_crop,
features.Image: K.five_crop_image,
}
)
def five_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.ten_crop,
PIL.Image.Image: _F.ten_crop,
features.Image: K.ten_crop_image,
}
)
def ten_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
from torchvision.prototype.transforms import InterpolationMode
from torchvision.prototype.transforms.functional import get_image_size
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix
from ._meta_conversion import convert_bounding_box_format
horizontal_flip_image_tensor = _FT.hflip
horizontal_flip_image_pil = _FP.hflip
def horizontal_flip_bounding_box(
bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
shape = bounding_box.shape
bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]]
return convert_bounding_box_format(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
).view(shape)
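Note that `image_size` here is `(height, width)`, so `image_size[1]` is the width used to mirror the x coordinates. A worked example with assumed values:

import torch
from torchvision.prototype import features

box = torch.tensor([[10.0, 5.0, 30.0, 25.0]])  # XYXY box on a (height=50, width=100) image
horizontal_flip_bounding_box(box, format=features.BoundingBoxFormat.XYXY, image_size=(50, 100))
# -> tensor([[70., 5., 90., 25.]]) since x1, x2 become 100 - 30 and 100 - 10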
def resize_image_tensor(
image: torch.Tensor,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> torch.Tensor:
new_height, new_width = size
old_width, old_height = _FT.get_image_size(image)
num_channels = _FT.get_image_num_channels(image)
batch_shape = image.shape[:-3]
return _FT.resize(
image.reshape((-1, num_channels, old_height, old_width)),
size=size,
interpolation=interpolation.value,
max_size=max_size,
antialias=antialias,
).reshape(batch_shape + (num_channels, new_height, new_width))
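The reshape round-trip above lets the tensor kernel accept arbitrary leading batch dimensions, not just `(N, C, H, W)`. A hedged sketch (shapes assumed):

import torch

stack = torch.rand(2, 5, 3, 32, 32)  # two batches of five RGB images
resized = resize_image_tensor(stack, size=[16, 16])
assert resized.shape == (2, 5, 3, 16, 16)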
def resize_image_pil(
img: PIL.Image.Image,
size: Union[Sequence[int], int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
) -> PIL.Image.Image:
return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation], max_size=max_size)
def resize_segmentation_mask(
segmentation_mask: torch.Tensor, size: List[int], max_size: Optional[int] = None
) -> torch.Tensor:
return resize_image_tensor(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
# TODO: handle max_size
def resize_bounding_box(bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor:
old_height, old_width = image_size
new_height, new_width = size
ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)
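Boxes only need the per-axis scale factors, so the kernel is a single multiply over (x, y) pairs. A worked example with assumed values:

import torch

box = torch.tensor([[20.0, 10.0, 60.0, 30.0]])  # XYXY
resize_bounding_box(box, size=[50, 100], image_size=(100, 200))  # (height, width) halved
# -> tensor([[10., 5., 30., 15.]])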
vertical_flip_image_tensor = _FT.vflip
vertical_flip_image_pil = _FP.vflip
def _affine_parse_args(
angle: float,
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")
if not isinstance(translate, (list, tuple)):
raise TypeError("Argument translate should be a sequence")
if len(translate) != 2:
raise ValueError("Argument translate should be a sequence of length 2")
if scale <= 0.0:
raise ValueError("Argument scale should be positive")
if not isinstance(shear, (numbers.Number, (list, tuple))):
raise TypeError("Shear should be either a single value or a sequence of two values")
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")
if isinstance(angle, int):
angle = float(angle)
if isinstance(translate, tuple):
translate = list(translate)
if isinstance(shear, numbers.Number):
shear = [shear, 0.0]
if isinstance(shear, tuple):
shear = list(shear)
if len(shear) == 1:
shear = [shear[0], shear[0]]
if len(shear) != 2:
raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")
if center is not None and not isinstance(center, (list, tuple)):
raise TypeError("Argument center should be a sequence")
return angle, translate, shear, center
def affine_image_tensor(
img: torch.Tensor,
angle: float,
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
center_f = [0.0, 0.0]
if center is not None:
width, height = get_image_size(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]
translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
return _FT.affine(img, matrix, interpolation=interpolation.value, fill=fill)
def affine_image_pil(
img: PIL.Image.Image,
angle: float,
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
if center is None:
width, height = get_image_size(img)
center = [width * 0.5, height * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill, center=center)
def rotate_image_tensor(
img: torch.Tensor,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
center_f = [0.0, 0.0]
if center is not None:
width, height = get_image_size(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]
# due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle.
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
return _FT.rotate(img, matrix, interpolation=interpolation.value, expand=expand, fill=fill)
def rotate_image_pil(
img: PIL.Image.Image,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
return _FP.rotate(
img, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
)
pad_image_tensor = _FT.pad
pad_image_pil = _FP.pad
crop_image_tensor = _FT.crop
crop_image_pil = _FP.crop
def perspective_image_tensor(
img: torch.Tensor,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
) -> torch.Tensor:
return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill)
def perspective_image_pil(
img: PIL.Image.Image,
perspective_coeffs: float,
interpolation: InterpolationMode = InterpolationMode.BICUBIC,
fill: Optional[List[float]] = None,
) -> PIL.Image.Image:
return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
if isinstance(output_size, numbers.Number):
return [int(output_size), int(output_size)]
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
return [output_size[0], output_size[0]]
else:
return list(output_size)
def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
return [
(crop_width - image_width) // 2 if crop_width > image_width else 0,
(crop_height - image_height) // 2 if crop_height > image_height else 0,
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
def _center_crop_compute_crop_anchor(
crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
crop_top = int(round((image_height - crop_height) / 2.0))
crop_left = int(round((image_width - crop_width) / 2.0))
return crop_top, crop_left
def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
image_width, image_height = get_image_size(img)
if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
img = pad_image_tensor(img, padding_ltrb, fill=0)
image_width, image_height = get_image_size(img)
if crop_width == image_width and crop_height == image_height:
return img
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
return crop_image_tensor(img, crop_top, crop_left, crop_height, crop_width)
def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
image_width, image_height = get_image_size(img)
if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
img = pad_image_pil(img, padding_ltrb, fill=0)
image_width, image_height = get_image_size(img)
if crop_width == image_width and crop_height == image_height:
return img
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
return crop_image_pil(img, crop_top, crop_left, crop_height, crop_width)
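# Usage sketch (hypothetical helper, illustrative sizes): center-cropping to a
# size larger than the input first zero-pads the image, so the output always
# has exactly `output_size` spatial dimensions.
def _center_crop_image_tensor_example() -> torch.Tensor:
    img = torch.rand(3, 20, 20)
    return center_crop_image_tensor(img, output_size=[32, 32])  # padded, then cropped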
def resized_crop_image_tensor(
img: torch.Tensor,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
img = crop_image_tensor(img, top, left, height, width)
return resize_image_tensor(img, size, interpolation=interpolation)
def resized_crop_image_pil(
img: PIL.Image.Image,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
img = crop_image_pil(img, top, left, height, width)
return resize_image_pil(img, size, interpolation=interpolation)
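# Usage sketch (hypothetical helper, illustrative values): resized_crop is a
# crop followed by a resize, e.g. extracting a 16x16 patch at (top=4, left=4)
# and scaling it up to 32x32.
def _resized_crop_image_tensor_example() -> torch.Tensor:
    img = torch.rand(3, 24, 24)
    return resized_crop_image_tensor(img, top=4, left=4, height=16, width=16, size=[32, 32])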
from typing import Any
import PIL.Image
import torch
from torchvision.ops import box_convert
from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F
from ._utils import dispatch
@dispatch(
{
torch.Tensor: None,
features.BoundingBox: None,
}
)
def convert_format(input: Any, *args: Any, **kwargs: Any) -> Any:
format = kwargs["format"]
if type(input) is torch.Tensor:
old_format = kwargs.get("old_format")
if old_format is None:
raise TypeError("For vanilla tensors the `old_format` needs to be provided.")
        return box_convert(input, in_fmt=old_format.name.lower(), out_fmt=format.name.lower())
    elif isinstance(input, features.BoundingBox):
        output = K.convert_bounding_box_format(input, old_format=input.format, new_format=format)
        return features.BoundingBox.new_like(input, output, format=format)
    raise RuntimeError(f"Unsupported input type: {type(input).__name__}")
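# Usage sketch (hypothetical helper, illustrative box values; assumes the
# dispatch decorator forwards keyword arguments to the function body for types
# registered with a None kernel): vanilla tensors carry no format metadata, so
# `old_format` must be passed explicitly, while features.BoundingBox inputs
# use their own `.format`.
def _convert_format_example() -> torch.Tensor:
    xywh = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
    return convert_format(
        xywh, format=features.BoundingBoxFormat.XYXY, old_format=features.BoundingBoxFormat.XYWH
    )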
@dispatch(
{
torch.Tensor: None,
PIL.Image.Image: None,
features.Image: None,
}
)
def convert_color_space(input: Any, *args: Any, **kwargs: Any) -> Any:
color_space = kwargs["color_space"]
if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image):
if color_space != features.ColorSpace.GRAYSCALE:
raise ValueError("For vanilla tensors and PIL images only RGB to grayscale is supported")
return _F.rgb_to_grayscale(input)
elif isinstance(input, features.Image):
output = K.convert_color_space(input, old_color_space=input.color_space, new_color_space=color_space)
return features.Image.new_like(input, output, color_space=color_space)
    raise RuntimeError(f"Unsupported input type: {type(input).__name__}")
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
xyxy = xywh.clone()
xyxy[..., 2:] += xyxy[..., :2]
return xyxy
def _xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
xywh = xyxy.clone()
xywh[..., 2:] -= xywh[..., :2]
return xywh
def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor:
cx, cy, w, h = torch.unbind(cxcywh, dim=-1)
x1 = cx - 0.5 * w
y1 = cy - 0.5 * h
x2 = cx + 0.5 * w
y2 = cy + 0.5 * h
return torch.stack((x1, y1, x2, y2), dim=-1)
def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
x1, y1, x2, y2 = torch.unbind(xyxy, dim=-1)
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
return torch.stack((cx, cy, w, h), dim=-1)
def convert_bounding_box_format(
bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
) -> torch.Tensor:
if new_format == old_format:
return bounding_box.clone()
if old_format == BoundingBoxFormat.XYWH:
bounding_box = _xywh_to_xyxy(bounding_box)
elif old_format == BoundingBoxFormat.CXCYWH:
bounding_box = _cxcywh_to_xyxy(bounding_box)
if new_format == BoundingBoxFormat.XYWH:
bounding_box = _xyxy_to_xywh(bounding_box)
elif new_format == BoundingBoxFormat.CXCYWH:
bounding_box = _xyxy_to_cxcywh(bounding_box)
return bounding_box
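# Round-trip sketch (hypothetical helper, illustrative values): converting
# XYWH -> XYXY and back recovers the original box, since every conversion is
# routed through the XYXY format.
def _convert_bounding_box_format_example() -> torch.Tensor:
    xywh = torch.tensor([10.0, 20.0, 30.0, 40.0])
    xyxy = convert_bounding_box_format(xywh, old_format=BoundingBoxFormat.XYWH, new_format=BoundingBoxFormat.XYXY)
    # xyxy == tensor([10., 20., 40., 60.])
    return convert_bounding_box_format(xyxy, old_format=BoundingBoxFormat.XYXY, new_format=BoundingBoxFormat.XYWH)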
def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
    # expand the single channel to three; -1 keeps the spatial dimensions
    return grayscale.expand(3, -1, -1)
def convert_image_color_space_tensor(
image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
) -> torch.Tensor:
if new_color_space == old_color_space:
return image.clone()
if old_color_space == ColorSpace.GRAYSCALE:
image = _grayscale_to_rgb_tensor(image)
if new_color_space == ColorSpace.GRAYSCALE:
image = _FT.rgb_to_grayscale(image)
return image
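# Usage sketch (hypothetical helper, illustrative tensor): a (1, H, W)
# grayscale image expands to a (3, H, W) RGB image; converting to the same
# color space returns a clone instead.
def _convert_image_color_space_tensor_example() -> torch.Tensor:
    gray = torch.rand(1, 8, 8)
    return convert_image_color_space_tensor(
        gray, old_color_space=ColorSpace.GRAYSCALE, new_color_space=ColorSpace.RGB
    )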
def _grayscale_to_rgb_pil(grayscale: PIL.Image.Image) -> PIL.Image.Image:
return grayscale.convert("RGB")
def convert_image_color_space_pil(
image: PIL.Image.Image, old_color_space: ColorSpace, new_color_space: ColorSpace
) -> PIL.Image.Image:
if new_color_space == old_color_space:
return image.copy()
if old_color_space == ColorSpace.GRAYSCALE:
image = _grayscale_to_rgb_pil(image)
if new_color_space == ColorSpace.GRAYSCALE:
image = _FP.to_grayscale(image)
return image
from typing import Any
from typing import Optional, List
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F
from torchvision.transforms.functional_pil import (
get_image_size as _get_image_size_pil,
get_image_num_channels as _get_image_num_channels_pil,
)
from torchvision.transforms.functional_tensor import (
get_image_size as _get_image_size_tensor,
get_image_num_channels as _get_image_num_channels_tensor,
)
from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import to_tensor, to_pil_image
from ._utils import dispatch
normalize_image_tensor = _FT.normalize
@dispatch(
{
torch.Tensor: _F.normalize,
features.Image: K.normalize_image,
}
)
def normalize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
def gaussian_blur_image_tensor(
    img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for ksize in kernel_size:
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")
    if sigma is None:
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
    if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
    if isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
    for s in sigma:
        if s <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")
    return _FT.gaussian_blur(img, kernel_size, sigma)
def gaussian_blur_image_pil(
    img: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> PIL.Image.Image:
    return to_pil_image(gaussian_blur_image_tensor(to_tensor(img), kernel_size=kernel_size, sigma=sigma))
@dispatch(
    {
        torch.Tensor: _F.gaussian_blur,
        PIL.Image.Image: _F.gaussian_blur,
        features.Image: K.gaussian_blur_image,
    }
)
def gaussian_blur(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...
@dispatch(
    {
        torch.Tensor: _get_image_size_tensor,
        PIL.Image.Image: _get_image_size_pil,
        features.Image: None,
        features.BoundingBox: None,
    }
)
def get_image_size(input: Any, *args: Any, **kwargs: Any) -> Any:
    if isinstance(input, (features.Image, features.BoundingBox)):
        return list(input.image_size)
    raise RuntimeError(f"Unsupported input type: {type(input).__name__}")
@dispatch(
    {
        torch.Tensor: _get_image_num_channels_tensor,
        PIL.Image.Image: _get_image_num_channels_pil,
        features.Image: None,
    }
)
def get_image_num_channels(input: Any, *args: Any, **kwargs: Any) -> Any:
    if isinstance(input, features.Image):
        return input.num_channels
    raise RuntimeError(f"Unsupported input type: {type(input).__name__}")
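# Usage sketch (hypothetical helper, illustrative shape; assumes the dispatch
# decorator forwards plain tensors to the registered kernel): get_image_size
# dispatches on input type -- tensors and PIL images go through the legacy
# helpers, while prototype features read the metadata stored on the feature.
def _get_image_size_example() -> None:
    tensor_image = torch.rand(3, 16, 24)
    assert get_image_size(tensor_image) == [24, 16]  # (width, height)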