Unverified Commit 52e6bd08 authored by Philip Meier, committed by GitHub

add prototype transforms that use the prototype dispatchers (#5418)

* add prototype transforms that use the prototype dispatchers

Conflicts:
	torchvision/prototype/transforms/__init__.py

* simplify

* add logger

* remove legacy classes

Conflicts:
	torchvision/prototype/transforms/_augment.py
	torchvision/prototype/transforms/_auto_augment.py
	torchvision/prototype/transforms/_geometry.py

* make get_params private

* remove randbool method

* remove AutoAugmentDispatcher

* add high level kernels for meta conversion

* remove transforms meta abstraction from auto augment transforms

* appease mypy

* add smoke tests for transforms

* remove Query object

* remove extra_repr helper

* fix tests

* appease mypy

* revert some changes on the kernel tests

* fix dispatcher annotations

* remove float cast for torch.rand

* add helper to query image

* fix imports

* address auto augment comments

* cleanup
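For orientation, here is a minimal usage sketch of the API this commit adds. The transform names are taken from the diff below; the exact tensor shape and the direct features.Image constructor call are assumptions for illustration.

import torch
from torchvision.prototype import features, transforms

# Compose takes the transforms as varargs rather than as a list
# (see _container.py below).
pipeline = transforms.Compose(
    transforms.HorizontalFlip(),
    transforms.Resize([16, 16]),
)

# Each transform dispatches on the input type, so a feature-wrapped tensor,
# a vanilla tensor, or a PIL image can be passed in directly.
image = features.Image(torch.rand(3, 32, 32))
output = pipeline(image)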
@@ -99,7 +99,6 @@ class TestCommon:
                 f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
             )
 
-    @pytest.mark.xfail
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_transformable(self, test_home, dataset_mock, config):
         dataset_mock.prepare(test_home, config)
...
import itertools
import PIL.Image
import pytest
import torch
from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels
from torchvision.prototype import transforms, features
from torchvision.transforms.functional import to_pil_image
def make_vanilla_tensor_images(*args, **kwargs):
for image in make_images(*args, **kwargs):
if image.ndim > 3:
continue
yield image.data
def make_pil_images(*args, **kwargs):
for image in make_vanilla_tensor_images(*args, **kwargs):
yield to_pil_image(image)
def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
for bounding_box in make_bounding_boxes(*args, **kwargs):
yield bounding_box.data
INPUT_CREATIONS_FNS = {
features.Image: make_images,
features.BoundingBox: make_bounding_boxes,
features.OneHotLabel: make_one_hot_labels,
torch.Tensor: make_vanilla_tensor_images,
PIL.Image.Image: make_pil_images,
}
def parametrize(transforms_with_inputs):
return pytest.mark.parametrize(
("transform", "input"),
[
pytest.param(
transform,
input,
id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}",
)
for transform, inputs in transforms_with_inputs
for idx, input in enumerate(inputs)
],
)
def parametrize_from_transforms(*transforms):
transforms_with_inputs = []
for transform in transforms:
dispatcher = transform._DISPATCHER
if dispatcher is None:
continue
for type_ in dispatcher._kernels:
try:
inputs = INPUT_CREATIONS_FNS[type_]()
except KeyError:
continue
transforms_with_inputs.append((transform, inputs))
return parametrize(transforms_with_inputs)
class TestSmoke:
@parametrize_from_transforms(
transforms.RandomErasing(),
transforms.HorizontalFlip(),
transforms.Resize([16, 16]),
transforms.CenterCrop([16, 16]),
transforms.ConvertImageDtype(),
)
def test_common(self, transform, input):
transform(input)
@parametrize(
[
(
transform,
[
dict(
image=features.Image.new_like(image, image.unsqueeze(0), dtype=torch.float),
one_hot_label=features.OneHotLabel.new_like(
one_hot_label, one_hot_label.unsqueeze(0), dtype=torch.float
),
)
for image, one_hot_label in itertools.product(make_images(), make_one_hot_labels())
],
)
for transform in [
transforms.RandomMixup(alpha=1.0),
transforms.RandomCutmix(alpha=1.0),
]
]
)
def test_mixup_cutmix(self, transform, input):
transform(input)
@parametrize(
[
(
transform,
itertools.chain.from_iterable(
fn(dtypes=[torch.uint8], extra_dims=[(4,)])
for fn in [
make_images,
make_vanilla_tensor_images,
make_pil_images,
]
),
)
for transform in (
transforms.RandAugment(),
transforms.TrivialAugmentWide(),
transforms.AutoAugment(),
)
]
)
def test_auto_augment(self, transform, input):
transform(input)
@parametrize(
[
(
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
itertools.chain.from_iterable(
fn(color_spaces=["rgb"], dtypes=[torch.float32])
for fn in [
make_images,
make_vanilla_tensor_images,
]
),
),
]
)
def test_normalize(self, transform, input):
transform(input)
@parametrize(
[
(
transforms.ConvertColorSpace("grayscale"),
itertools.chain(
make_images(),
make_vanilla_tensor_images(color_spaces=["rgb"]),
make_pil_images(color_spaces=["rgb"]),
),
)
]
)
def test_convert_color_space(self, transform, input):
transform(input)
@parametrize(
[
(
transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"),
itertools.chain(
make_bounding_boxes(),
make_vanilla_tensor_bounding_boxes(formats=["xywh"]),
),
)
]
)
def test_convert_bounding_box_format(self, transform, input):
transform(input)
@parametrize(
[
(
transforms.RandomResizedCrop([16, 16]),
itertools.chain(
make_images(extra_dims=[(4,)]),
make_vanilla_tensor_images(),
make_pil_images(),
),
)
]
)
def test_random_resized_crop(self, transform, input):
transform(input)
@@ -5,6 +5,7 @@ import pytest
 import torch.testing
 import torchvision.prototype.transforms.kernels as K
 from torch import jit
+from torch.nn.functional import one_hot
 from torchvision.prototype import features
 
 make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
@@ -39,10 +40,10 @@ def make_images(
     extra_dims=((4,), (2, 3)),
 ):
     for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
-        yield make_image(size, color_space=color_space)
+        yield make_image(size, color_space=color_space, dtype=dtype)
 
-    for color_space, extra_dims_ in itertools.product(color_spaces, extra_dims):
-        yield make_image(color_space=color_space, extra_dims=extra_dims_)
+    for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims):
+        yield make_image(color_space=color_space, extra_dims=extra_dims_, dtype=dtype)
 
 
 def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
@@ -106,6 +107,27 @@ def make_bounding_boxes(
         yield make_bounding_box(format=format, extra_dims=extra_dims_)
 
 
+def make_label(size=(), *, categories=("category0", "category1")):
+    return features.Label(torch.randint(0, len(categories) if categories else 10, size), categories=categories)
+
+
+def make_one_hot_label(*args, **kwargs):
+    label = make_label(*args, **kwargs)
+    return features.OneHotLabel(one_hot(label, num_classes=len(label.categories)), categories=label.categories)
+
+
+def make_one_hot_labels(
+    *,
+    num_categories=(1, 2, 10),
+    extra_dims=((4,), (2, 3)),
+):
+    for num_categories_ in num_categories:
+        yield make_one_hot_label(categories=[f"category{idx}" for idx in range(num_categories_)])
+
+    for extra_dims_ in extra_dims:
+        yield make_one_hot_label(extra_dims_)
+
+
 class SampleInput:
     def __init__(self, *args, **kwargs):
         self.args = args
...
+from torchvision.transforms import AutoAugmentPolicy, InterpolationMode  # usort: skip
 from . import kernels  # usort: skip
 from . import functional  # usort: skip
-from .kernels import InterpolationMode  # usort: skip
+from ._transform import Transform  # usort: skip
+from ._augment import RandomErasing, RandomMixup, RandomCutmix
+from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
+from ._container import Compose, RandomApply, RandomChoice, RandomOrder
+from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
+from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace
+from ._misc import Identity, Normalize, ToDtype, Lambda
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
+from ._type_conversion import DecodeImage, LabelToOneHot
import math
import numbers
import warnings
from typing import Any, Dict, Tuple
import torch
from torchvision.prototype.transforms import Transform, functional as F
from ._utils import query_image
class RandomErasing(Transform):
_DISPATCHER = F.erase
def __init__(
self,
p: float = 0.5,
scale: Tuple[float, float] = (0.02, 0.33),
ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0,
):
super().__init__()
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
raise ValueError("If value is str, it should be 'random'")
if not isinstance(scale, (tuple, list)):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, (tuple, list)):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1:
raise ValueError("Scale should be between 0 and 1")
if p < 0 or p > 1:
raise ValueError("Random erasing probability should be between 0 and 1")
# TODO: deprecate p in favor of wrapping the transform in a RandomApply
self.p = p
self.scale = scale
self.ratio = ratio
self.value = value
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
img_h, img_w = F.get_image_size(image)
img_c = F.get_image_num_channels(image)
if isinstance(self.value, (int, float)):
value = [self.value]
elif isinstance(self.value, str):
value = None
elif isinstance(self.value, tuple):
value = list(self.value)
else:
value = self.value
if value is not None and not (len(value) in (1, img_c)):
raise ValueError(
f"If value is a sequence, it should have either a single value or {img_c} (number of input channels)"
)
area = img_h * img_w
log_ratio = torch.log(torch.tensor(self.ratio))
for _ in range(10):
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio)))
if not (h < img_h and w < img_w):
continue
if value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(value)[:, None, None]
i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
break
else:
i, j, h, w, v = 0, 0, img_h, img_w, image
return dict(zip("ijhwv", (i, j, h, w, v)))
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if torch.rand(1) >= self.p:
return input
return super()._transform(input, params)
class RandomMixup(Transform):
_DISPATCHER = F.mixup
def __init__(self, *, alpha: float) -> None:
super().__init__()
self.alpha = alpha
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(())))
class RandomCutmix(Transform):
_DISPATCHER = F.cutmix
def __init__(self, *, alpha: float) -> None:
super().__init__()
self.alpha = alpha
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
def _get_params(self, sample: Any) -> Dict[str, Any]:
lam = float(self._dist.sample(()))
image = query_image(sample)
H, W = F.get_image_size(image)
r_x = torch.randint(W, ())
r_y = torch.randint(H, ())
r = 0.5 * math.sqrt(1.0 - lam)
r_w_half = int(r * W)
r_h_half = int(r * H)
x1 = int(torch.clamp(r_x - r_w_half, min=0))
y1 = int(torch.clamp(r_y - r_h_half, min=0))
x2 = int(torch.clamp(r_x + r_w_half, max=W))
y2 = int(torch.clamp(r_y + r_h_half, max=H))
box = (x1, y1, x2, y2)
lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
return dict(box=box, lam_adjusted=lam_adjusted)
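RandomMixup and RandomCutmix above only sample the parameters; the mixing itself is dispatched through F.mixup / F.cutmix to per-type kernels, so the image and the one-hot label in a sample are transformed with the same draw. A sketch mirroring TestSmoke.test_mixup_cutmix (the batched float inputs and the bare OneHotLabel constructor call are assumptions taken from that test):

import torch
from torchvision.prototype import features, transforms

mixup = transforms.RandomMixup(alpha=1.0)

# One lam is drawn in _get_params and applied to every supported value;
# anything the dispatcher does not know passes through unchanged.
sample = dict(
    image=features.Image(torch.rand(4, 3, 16, 16)),
    one_hot_label=features.OneHotLabel(torch.rand(4, 10)),
)
mixed = mixup(sample)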
import math
from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
from torchvision.prototype.utils._internal import apply_recursively
from ._utils import query_image
K = TypeVar("K")
V = TypeVar("V")
class _AutoAugmentBase(Transform):
def __init__(
self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None
) -> None:
super().__init__()
self.interpolation = interpolation
self.fill = fill
_DISPATCHER_MAP: Dict[str, Callable[[Any, float, InterpolationMode, Optional[List[float]]], Any]] = {
"Identity": lambda input, magnitude, interpolation, fill: input,
"ShearX": lambda input, magnitude, interpolation, fill: F.affine(
input,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(magnitude), 0.0],
interpolation=interpolation,
fill=fill,
),
"ShearY": lambda input, magnitude, interpolation, fill: F.affine(
input,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(magnitude)],
interpolation=interpolation,
fill=fill,
),
"TranslateX": lambda input, magnitude, interpolation, fill: F.affine(
input,
angle=0.0,
translate=[int(magnitude), 0],
scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation,
fill=fill,
),
"TranslateY": lambda input, magnitude, interpolation, fill: F.affine(
input,
angle=0.0,
translate=[0, int(magnitude)],
scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation,
fill=fill,
),
"Rotate": lambda input, magnitude, interpolation, fill: F.rotate(input, angle=magnitude),
"Brightness": lambda input, magnitude, interpolation, fill: F.adjust_brightness(
input, brightness_factor=1.0 + magnitude
),
"Color": lambda input, magnitude, interpolation, fill: F.adjust_saturation(
input, saturation_factor=1.0 + magnitude
),
"Contrast": lambda input, magnitude, interpolation, fill: F.adjust_contrast(
input, contrast_factor=1.0 + magnitude
),
"Sharpness": lambda input, magnitude, interpolation, fill: F.adjust_sharpness(
input, sharpness_factor=1.0 + magnitude
),
"Posterize": lambda input, magnitude, interpolation, fill: F.posterize(input, bits=int(magnitude)),
"Solarize": lambda input, magnitude, interpolation, fill: F.solarize(input, threshold=magnitude),
"AutoContrast": lambda input, magnitude, interpolation, fill: F.autocontrast(input),
"Equalize": lambda input, magnitude, interpolation, fill: F.equalize(input),
"Invert": lambda input, magnitude, interpolation, fill: F.invert(input),
}
def _is_supported(self, obj: Any) -> bool:
return type(obj) in {features.Image, torch.Tensor} or isinstance(obj, PIL.Image.Image)
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
num_channels = F.get_image_num_channels(image)
fill = self.fill
if isinstance(fill, (int, float)):
fill = [float(fill)] * num_channels
elif fill is not None:
fill = [float(f) for f in fill]
return dict(interpolation=self.interpolation, fill=fill)
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _apply_transform(self, sample: Any, params: Dict[str, Any], transform_id: str, magnitude: float) -> Any:
dispatcher = self._DISPATCHER_MAP[transform_id]
def transform(input: Any) -> Any:
if not self._is_supported(input):
return input
return dispatcher(input, magnitude, params["interpolation"], params["fill"])
return apply_recursively(transform, sample)
class AutoAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, image_size: None, False),
"Equalize": (lambda num_bins, image_size: None, False),
"Invert": (lambda num_bins, image_size: None, False),
}
def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.policy = policy
self._policies = self._get_policies(policy)
def _get_policies(
self, policy: AutoAugmentPolicy
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
if policy == AutoAugmentPolicy.IMAGENET:
return [
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
]
elif policy == AutoAugmentPolicy.CIFAR10:
return [
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
]
elif policy == AutoAugmentPolicy.SVHN:
return [
(("ShearX", 0.9, 4), ("Invert", 0.2, None)),
(("ShearY", 0.9, 8), ("Invert", 0.7, None)),
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
(("ShearY", 0.9, 8), ("Invert", 0.4, None)),
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
(("ShearY", 0.8, 8), ("Invert", 0.7, None)),
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
(("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
(("Invert", 0.6, None), ("Rotate", 0.8, 4)),
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
(("ShearX", 0.1, 6), ("Invert", 0.6, None)),
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
(("ShearY", 0.8, 4), ("Invert", 0.8, None)),
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
]
else:
raise ValueError(f"The provided policy {policy} is not recognized.")
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
params = params or self._get_params(sample)
image = query_image(sample)
image_size = F.get_image_size(image)
policy = self._policies[int(torch.randint(len(self._policies), ()))]
for transform_id, probability, magnitude_idx in policy:
if not torch.rand(()) <= probability:
continue
magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
magnitudes = magnitudes_fn(10, image_size)
if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
sample = self._apply_transform(sample, params, transform_id, magnitude)
return sample
class RandAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, image_size: None, False),
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, image_size: None, False),
"Equalize": (lambda num_bins, image_size: None, False),
}
def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.num_ops = num_ops
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
params = params or self._get_params(sample)
image = query_image(sample)
image_size = F.get_image_size(image)
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
sample = self._apply_transform(sample, params, transform_id, magnitude)
return sample
class TrivialAugmentWide(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, image_size: None, False),
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (
lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, image_size: None, False),
"Equalize": (lambda num_bins, image_size: None, False),
}
def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any):
super().__init__(**kwargs)
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
params = params or self._get_params(sample)
image = query_image(sample)
image_size = F.get_image_size(image)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
return self._apply_transform(sample, params, transform_id, magnitude)
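All three auto-augment transforms share the same magnitude logic: an _AUGMENTATION_SPACE entry yields a tensor of candidate magnitudes plus a signed flag, one bin is picked uniformly, and the sign is flipped half of the time. Isolated as a sketch:

import torch

# e.g. the "ShearX" entry of TrivialAugmentWide._AUGMENTATION_SPACE
def magnitudes_fn(num_bins, image_size):
    return torch.linspace(0.0, 0.99, num_bins)

signed = True
num_magnitude_bins = 31

magnitudes = magnitudes_fn(num_magnitude_bins, (32, 32))
magnitude = float(magnitudes[int(torch.randint(num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
    magnitude *= -1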
from typing import Any, Optional, Dict
import torch
from ._transform import Transform
class Compose(Transform):
def __init__(self, *transforms: Transform):
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: # type: ignore[override]
sample = inputs if len(inputs) > 1 else inputs[0]
for transform in self.transforms:
sample = transform(sample)
return sample
class RandomApply(Transform):
def __init__(self, transform: Transform, *, p: float = 0.5) -> None:
super().__init__()
self.transform = transform
self.p = p
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if float(torch.rand(())) >= self.p:
return sample
return self.transform(sample, params=params)
def extra_repr(self) -> str:
return f"p={self.p}"
class RandomChoice(Transform):
def __init__(self, *transforms: Transform):
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: # type: ignore[override]
idx = int(torch.randint(len(self.transforms), size=()))
transform = self.transforms[idx]
return transform(*inputs)
class RandomOrder(Transform):
def __init__(self, *transforms: Transform):
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: # type: ignore[override]
for idx in torch.randperm(len(self.transforms)):
transform = self.transforms[idx]
inputs = transform(*inputs)
return inputs
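The containers accept arbitrary Transforms and register them as submodules, so parameters and buffers of wrapped transforms stay visible to the nn.Module machinery. A usage sketch (the tensor shape and the features.Image call are illustrative assumptions):

import torch
from torchvision.prototype import features, transforms

# Apply the flip with probability p, otherwise return the sample unchanged.
maybe_flip = transforms.RandomApply(transforms.HorizontalFlip(), p=0.5)

# Pick exactly one of the listed transforms per call.
one_of = transforms.RandomChoice(
    transforms.CenterCrop([16, 16]),
    transforms.RandomResizedCrop([16, 16]),
)

image = features.Image(torch.rand(3, 32, 32))
output = one_of(maybe_flip(image))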
import math
import warnings
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
import torch
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
from ._utils import query_image
class HorizontalFlip(Transform):
_DISPATCHER = F.horizontal_flip
class Resize(Transform):
_DISPATCHER = F.resize
def __init__(
self,
size: Union[int, Sequence[int]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self.size = size
self.interpolation = interpolation
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(size=self.size, interpolation=self.interpolation)
class CenterCrop(Transform):
_DISPATCHER = F.center_crop
def __init__(self, output_size: List[int]):
super().__init__()
self.output_size = output_size
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(output_size=self.output_size)
class RandomResizedCrop(Transform):
_DISPATCHER = F.resized_crop
def __init__(
self,
size: Union[int, Sequence[int]],
scale: Tuple[float, float] = (0.08, 1.0),
ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if not isinstance(scale, Sequence):
raise TypeError("Scale should be a sequence")
scale = cast(Tuple[float, float], scale)
if not isinstance(ratio, Sequence):
raise TypeError("Ratio should be a sequence")
ratio = cast(Tuple[float, float], ratio)
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
self.scale = scale
self.ratio = ratio
self.interpolation = interpolation
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
height, width = F.get_image_size(image)
area = height * width
log_ratio = torch.log(torch.tensor(self.ratio))
for _ in range(10):
target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
i = torch.randint(0, height - h + 1, size=(1,)).item()
j = torch.randint(0, width - w + 1, size=(1,)).item()
break
else:
# Fallback to central crop
in_ratio = float(width) / float(height)
if in_ratio < min(self.ratio):
w = width
h = int(round(w / min(self.ratio)))
elif in_ratio > max(self.ratio):
h = height
w = int(round(h * max(self.ratio)))
else: # whole image
w = width
h = height
i = (height - h) // 2
j = (width - w) // 2
return dict(top=i, left=j, height=h, width=w, size=self.size)
from typing import Union, Any, Dict, Optional
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from torchvision.transforms.functional import convert_image_dtype
class ConvertBoundingBoxFormat(Transform):
_DISPATCHER = F.convert_format
def __init__(
self,
format: Union[str, features.BoundingBoxFormat],
old_format: Optional[Union[str, features.BoundingBoxFormat]] = None,
) -> None:
super().__init__()
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
self.format = format
if isinstance(old_format, str):
old_format = features.BoundingBoxFormat[old_format]
self.old_format = old_format
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(format=self.format, old_format=self.old_format)
class ConvertImageDtype(Transform):
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, features.Image):
return input
output = convert_image_dtype(input, dtype=self.dtype)
return features.Image.new_like(input, output, dtype=self.dtype)
class ConvertColorSpace(Transform):
_DISPATCHER = F.convert_color_space
def __init__(self, color_space: Union[str, features.ColorSpace]) -> None:
super().__init__()
if isinstance(color_space, str):
color_space = features.ColorSpace[color_space]
self.color_space = color_space
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(color_space=self.color_space)
import functools
from typing import Any, List, Type, Callable, Dict
import torch
from torchvision.prototype.transforms import Transform, functional as F
class Identity(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
return input
class Lambda(Transform):
def __init__(self, fn: Callable[[Any], Any], *types: Type):
super().__init__()
self.fn = fn
self.types = types
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, self.types):
return input
return self.fn(input)
def extra_repr(self) -> str:
extras = []
name = getattr(self.fn, "__name__", None)
if name:
extras.append(name)
extras.append(f"types={[type.__name__ for type in self.types]}")
return ", ".join(extras)
class Normalize(Transform):
_DISPATCHER = F.normalize
def __init__(self, mean: List[float], std: List[float]):
super().__init__()
self.mean = mean
self.std = std
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(mean=self.mean, std=self.std)
class ToDtype(Lambda):
def __init__(self, dtype: torch.dtype, *types: Type) -> None:
self.dtype = dtype
super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types)
def extra_repr(self) -> str:
return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"])
import enum
import functools
from typing import Any, Dict, Optional
from torch import nn
from torchvision.prototype.utils._internal import apply_recursively
from torchvision.utils import _log_api_usage_once
from .functional._utils import Dispatcher
class Transform(nn.Module):
_DISPATCHER: Optional[Dispatcher] = None
def __init__(self) -> None:
super().__init__()
_log_api_usage_once(self)
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict()
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not self._DISPATCHER:
raise NotImplementedError()
if input not in self._DISPATCHER:
return input
return self._DISPATCHER(input, **params)
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
return apply_recursively(functools.partial(self._transform, params=params or self._get_params(sample)), sample)
def extra_repr(self) -> str:
extra = []
for name, value in self.__dict__.items():
if name.startswith("_") or name == "training":
continue
if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)):
continue
extra.append(f"{name}={value}")
return ", ".join(extra)
from typing import Any, Dict
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, kernels as K
class DecodeImage(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, features.EncodedImage):
return input
return features.Image(K.decode_image_with_pil(input))
class LabelToOneHot(Transform):
def __init__(self, num_categories: int = -1):
super().__init__()
self.num_categories = num_categories
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, features.Label):
return input
num_categories = self.num_categories
if num_categories == -1 and input.categories is not None:
num_categories = len(input.categories)
return features.OneHotLabel(
K.label_to_one_hot(input, num_categories=num_categories), categories=input.categories
)
def extra_repr(self) -> str:
if self.num_categories == -1:
return ""
return f"num_categories={self.num_categories}"
from typing import Any, Union, Optional
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.utils._internal import query_recursively
def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:
if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
return input
return None
try:
return next(query_recursively(fn, sample))
except StopIteration:
raise TypeError("No image was found in the sample")
@@ -11,4 +11,5 @@ from ._color import (
     invert,
 )
 from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate
-from ._misc import normalize
+from ._meta_conversion import convert_color_space, convert_format
+from ._misc import normalize, get_image_size, get_image_num_channels
-from typing import TypeVar, Any
+from typing import Any
 
 import torch
 from torchvision.prototype import features
...
@@ -7,8 +7,6 @@ from torchvision.transforms import functional as _F
 from ._utils import dispatch
 
-T = TypeVar("T", bound=features._Feature)
-
 
 @dispatch(
     {
@@ -16,7 +14,7 @@ T = TypeVar("T", bound=features._Feature)
         features.Image: K.erase_image,
     }
 )
-def erase(input: T, *args: Any, **kwargs: Any) -> T:
+def erase(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -27,18 +25,18 @@ def erase(input: T, *args: Any, **kwargs: Any) -> T:
         features.OneHotLabel: K.mixup_one_hot_label,
     }
 )
-def mixup(input: T, *args: Any, **kwargs: Any) -> T:
+def mixup(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
 
 
 @dispatch(
     {
-        features.Image: K.cutmix_image,
-        features.OneHotLabel: K.cutmix_one_hot_label,
+        features.Image: None,
+        features.OneHotLabel: None,
     }
 )
-def cutmix(input: T, *args: Any, **kwargs: Any) -> T:
+def cutmix(input: Any, *args: Any, **kwargs: Any) -> Any:
     """Perform the CutMix operation as introduced in the paper
     `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" <https://arxiv.org/abs/1905.04899>`_.
@@ -54,4 +52,13 @@ def cutmix(input: T, *args: Any, **kwargs: Any) -> T:
     Please refer to the kernel documentations for a detailed explanation of the functionality and parameters.
     """
-    ...
+    if isinstance(input, features.Image):
+        kwargs.pop("lam_adjusted", None)
+        output = K.cutmix_image(input, **kwargs)
+        return features.Image.new_like(input, output)
+    elif isinstance(input, features.OneHotLabel):
+        kwargs.pop("box", None)
+        output = K.cutmix_one_hot_label(input, **kwargs)
+        return features.OneHotLabel.new_like(input, output)
+
+    raise RuntimeError
-from typing import TypeVar, Any
+from typing import Any
 
 import PIL.Image
 import torch
...
@@ -8,8 +8,6 @@ from torchvision.transforms import functional as _F
 from ._utils import dispatch
 
-T = TypeVar("T", bound=features._Feature)
-
 
 @dispatch(
     {
@@ -18,7 +16,7 @@ T = TypeVar("T", bound=features._Feature)
         features.Image: K.adjust_brightness_image,
     }
 )
-def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T:
+def adjust_brightness(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -30,7 +28,7 @@ def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.adjust_saturation_image,
     }
 )
-def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T:
+def adjust_saturation(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -42,7 +40,7 @@ def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.adjust_contrast_image,
     }
 )
-def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T:
+def adjust_contrast(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -54,7 +52,7 @@ def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.adjust_sharpness_image,
     }
 )
-def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T:
+def adjust_sharpness(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -66,7 +64,7 @@ def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.posterize_image,
     }
 )
-def posterize(input: T, *args: Any, **kwargs: Any) -> T:
+def posterize(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -78,7 +76,7 @@ def posterize(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.solarize_image,
     }
 )
-def solarize(input: T, *args: Any, **kwargs: Any) -> T:
+def solarize(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -90,7 +88,7 @@ def solarize(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.autocontrast_image,
     }
 )
-def autocontrast(input: T, *args: Any, **kwargs: Any) -> T:
+def autocontrast(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -102,7 +100,7 @@ def autocontrast(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.equalize_image,
     }
 )
-def equalize(input: T, *args: Any, **kwargs: Any) -> T:
+def equalize(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -114,7 +112,7 @@ def equalize(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.invert_image,
     }
 )
-def invert(input: T, *args: Any, **kwargs: Any) -> T:
+def invert(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -126,7 +124,7 @@ def invert(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.adjust_hue_image,
     }
 )
-def adjust_hue(input: T, *args: Any, **kwargs: Any) -> T:
+def adjust_hue(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -138,6 +136,6 @@ def adjust_hue(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.adjust_gamma_image,
     }
 )
-def adjust_gamma(input: T, *args: Any, **kwargs: Any) -> T:
+def adjust_gamma(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
-from typing import TypeVar, Any, cast
+from typing import Any
 
 import PIL.Image
 import torch
...
@@ -8,8 +8,6 @@ from torchvision.transforms import functional as _F
 from ._utils import dispatch
 
-T = TypeVar("T", bound=features._Feature)
-
 
 @dispatch(
     {
@@ -19,11 +17,11 @@ T = TypeVar("T", bound=features._Feature)
         features.BoundingBox: None,
     },
 )
-def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T:
+def horizontal_flip(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     if isinstance(input, features.BoundingBox):
         output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
-        return cast(T, features.BoundingBox.new_like(input, output))
+        return features.BoundingBox.new_like(input, output)
 
     raise RuntimeError
@@ -37,12 +35,12 @@ def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T:
         features.BoundingBox: None,
     }
 )
-def resize(input: T, *args: Any, **kwargs: Any) -> T:
+def resize(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     if isinstance(input, features.BoundingBox):
         size = kwargs.pop("size")
         output = K.resize_bounding_box(input, size=size, image_size=input.image_size)
-        return cast(T, features.BoundingBox.new_like(input, output, image_size=size))
+        return features.BoundingBox.new_like(input, output, image_size=size)
 
     raise RuntimeError
@@ -54,7 +52,7 @@ def resize(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.center_crop_image,
     }
 )
-def center_crop(input: T, *args: Any, **kwargs: Any) -> T:
+def center_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -66,7 +64,7 @@ def center_crop(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.resized_crop_image,
     }
 )
-def resized_crop(input: T, *args: Any, **kwargs: Any) -> T:
+def resized_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -78,7 +76,7 @@ def resized_crop(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.affine_image,
     }
 )
-def affine(input: T, *args: Any, **kwargs: Any) -> T:
+def affine(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -90,7 +88,7 @@ def affine(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.rotate_image,
     }
 )
-def rotate(input: T, *args: Any, **kwargs: Any) -> T:
+def rotate(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -102,7 +100,7 @@ def rotate(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.pad_image,
     }
 )
-def pad(input: T, *args: Any, **kwargs: Any) -> T:
+def pad(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -114,7 +112,7 @@ def pad(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.crop_image,
     }
 )
-def crop(input: T, *args: Any, **kwargs: Any) -> T:
+def crop(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -126,7 +124,7 @@ def crop(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.perspective_image,
     }
 )
-def perspective(input: T, *args: Any, **kwargs: Any) -> T:
+def perspective(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -138,7 +136,7 @@ def perspective(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.vertical_flip_image,
     }
 )
-def vertical_flip(input: T, *args: Any, **kwargs: Any) -> T:
+def vertical_flip(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -150,7 +148,7 @@ def vertical_flip(input: T, *args: Any, **kwargs: Any) -> T:
        features.Image: K.five_crop_image,
     }
 )
-def five_crop(input: T, *args: Any, **kwargs: Any) -> T:
+def five_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -162,6 +160,6 @@ def five_crop(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.ten_crop_image,
     }
 )
-def ten_crop(input: T, *args: Any, **kwargs: Any) -> T:
+def ten_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
from typing import Any
import PIL.Image
import torch
from torchvision.ops import box_convert
from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F
from ._utils import dispatch
@dispatch(
{
torch.Tensor: None,
features.BoundingBox: None,
}
)
def convert_format(input: Any, *args: Any, **kwargs: Any) -> Any:
format = kwargs["format"]
if type(input) is torch.Tensor:
old_format = kwargs.get("old_format")
if old_format is None:
raise TypeError("For vanilla tensors the `old_format` needs to be provided.")
return box_convert(input, in_fmt=old_format.name.lower(), out_fmt=format.name.lower())
elif isinstance(input, features.BoundingBox):
output = K.convert_bounding_box_format(input, old_format=input.format, new_format=format)
return features.BoundingBox.new_like(input, output, format=format)
raise RuntimeError
@dispatch(
{
torch.Tensor: None,
PIL.Image.Image: None,
features.Image: None,
}
)
def convert_color_space(input: Any, *args: Any, **kwargs: Any) -> Any:
color_space = kwargs["color_space"]
if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image):
if color_space != features.ColorSpace.GRAYSCALE:
raise ValueError("For vanilla tensors and PIL images only RGB to grayscale is supported")
return _F.rgb_to_grayscale(input)
elif isinstance(input, features.Image):
output = K.convert_color_space(input, old_color_space=input.color_space, new_color_space=color_space)
return features.Image.new_like(input, output, color_space=color_space)
raise RuntimeError
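For vanilla tensors, convert_format falls back to torchvision.ops.box_convert and therefore needs an explicit old_format, while a features.BoundingBox carries its own format. A sketch via the transform, mirroring test_convert_bounding_box_format above (the literal box values are illustrative):

import torch
from torchvision.prototype import transforms

convert = transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh")

# A vanilla tensor needs old_format; a features.BoundingBox would not.
xywh = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
xyxy = convert(xywh)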
-from typing import TypeVar, Any
+from typing import Any
 
 import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import kernels as K
 from torchvision.transforms import functional as _F
+from torchvision.transforms.functional_pil import (
+    get_image_size as _get_image_size_pil,
+    get_image_num_channels as _get_image_num_channels_pil,
+)
+from torchvision.transforms.functional_tensor import (
+    get_image_size as _get_image_size_tensor,
+    get_image_num_channels as _get_image_num_channels_tensor,
+)
 
 from ._utils import dispatch
 
-T = TypeVar("T", bound=features._Feature)
-
 
 @dispatch(
     {
@@ -17,7 +23,7 @@ T = TypeVar("T", bound=features._Feature)
         features.Image: K.normalize_image,
     }
 )
-def normalize(input: T, *args: Any, **kwargs: Any) -> T:
+def normalize(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
@@ -29,6 +35,35 @@ def normalize(input: T, *args: Any, **kwargs: Any) -> T:
         features.Image: K.gaussian_blur_image,
     }
 )
-def ten_gaussian_blur(input: T, *args: Any, **kwargs: Any) -> T:
+def gaussian_blur(input: Any, *args: Any, **kwargs: Any) -> Any:
     """TODO: add docstring"""
     ...
+
+
+@dispatch(
+    {
+        torch.Tensor: _get_image_size_tensor,
+        PIL.Image.Image: _get_image_size_pil,
+        features.Image: None,
+        features.BoundingBox: None,
+    }
+)
+def get_image_size(input: Any, *args: Any, **kwargs: Any) -> Any:
+    if isinstance(input, (features.Image, features.BoundingBox)):
+        return list(input.image_size)
+
+    raise RuntimeError
+
+
+@dispatch(
+    {
+        torch.Tensor: _get_image_num_channels_tensor,
+        PIL.Image.Image: _get_image_num_channels_pil,
+        features.Image: None,
+    }
+)
+def get_image_num_channels(input: Any, *args: Any, **kwargs: Any) -> Any:
+    if isinstance(input, features.Image):
+        return input.num_channels
+
+    raise RuntimeError
-import functools
 import inspect
-from typing import Any, Optional, Callable, TypeVar, Dict
+from typing import Any, Optional, Callable, TypeVar, Mapping, Type
 
 import torch
 import torch.overrides
@@ -9,10 +8,10 @@ from torchvision.prototype import features
 F = TypeVar("F", bound=features._Feature)
 
 
-def dispatch(kernels: Dict[Any, Optional[Callable]]) -> Callable[[Callable[..., F]], Callable[..., F]]:
-    """Decorates a function to automatically dispatch to registered kernels based on the call arguments.
+class Dispatcher:
+    """Wrap a function to automatically dispatch to registered kernels based on the call arguments.
 
-    The dispatch function should have this signature
+    The wrapped function should have this signature
 
     .. code:: python
@@ -34,7 +33,19 @@
         TypeError: If the decorated function is called with an input that cannot be dispatched.
     """
 
-    def check_kernel(kernel: Any) -> bool:
+    def __init__(self, fn: Callable, kernels: Mapping[Type, Optional[Callable]]):
+        self._fn = fn
+
+        for feature_type, kernel in kernels.items():
+            if not self._check_kernel(kernel):
+                raise TypeError(
+                    f"Kernel for feature type {feature_type.__name__} is not callable with "
+                    f"kernel(input, *args, **kwargs)."
+                )
+        self._kernels = kernels
+
+    def _check_kernel(self, kernel: Optional[Callable]) -> bool:
         if kernel is None:
             return True
@@ -47,43 +58,45 @@
         return params[0].kind != inspect.Parameter.KEYWORD_ONLY
 
-    for feature_type, kernel in kernels.items():
-        if not check_kernel(kernel):
-            raise TypeError(
-                f"Kernel for feature type {feature_type.__name__} is not callable with kernel(input, *args, **kwargs)."
-            )
-
-    def outer_wrapper(dispatch_fn: Callable[..., F]) -> Callable[..., F]:
-        @functools.wraps(dispatch_fn)
-        def inner_wrapper(input: F, *args: Any, **kwargs: Any) -> F:
-            feature_type = type(input)
-            try:
-                kernel = kernels[feature_type]
-            except KeyError:
-                try:
-                    feature_type, kernel = next(
-                        (feature_type, kernel)
-                        for feature_type, kernel in kernels.items()
-                        if isinstance(input, feature_type)
-                    )
-                except StopIteration:
-                    raise TypeError(f"No support for {type(input).__name__}") from None
+    def _resolve(self, feature_type: Type) -> Optional[Callable]:
+        try:
+            return self._kernels[feature_type]
+        except KeyError:
+            try:
+                return next(
+                    kernel
+                    for registered_feature_type, kernel in self._kernels.items()
+                    if issubclass(feature_type, registered_feature_type)
+                )
+            except StopIteration:
+                raise TypeError(f"No support for feature type {feature_type.__name__}") from None
 
-            if kernel is None:
-                output = dispatch_fn(input, *args, **kwargs)
-                if output is None:
-                    raise RuntimeError(
-                        f"{dispatch_fn.__name__}() did not handle inputs of type {type(input).__name__} "
-                        f"although it was configured to do so."
-                    )
-            else:
-                output = kernel(input, *args, **kwargs)
+    def __contains__(self, obj: Any) -> bool:
+        try:
+            self._resolve(type(obj))
+            return True
+        except TypeError:
+            return False
 
-            if issubclass(feature_type, features._Feature) and type(output) is torch.Tensor:
-                output = feature_type.new_like(input, output)
+    def __call__(self, input: Any, *args: Any, **kwargs: Any) -> Any:
+        kernel = self._resolve(type(input))
 
-            return output
+        if kernel is None:
+            output = self._fn(input, *args, **kwargs)
+            if output is None:
+                raise RuntimeError(
+                    f"{self._fn.__name__}() did not handle inputs of type {type(input).__name__} "
+                    f"although it was configured to do so."
+                )
+        else:
+            output = kernel(input, *args, **kwargs)
 
-        return inner_wrapper
+        if isinstance(input, features._Feature) and type(output) is torch.Tensor:
+            output = type(input).new_like(input, output)
 
-    return outer_wrapper
+        return output
+
+
+def dispatch(kernels: Mapping[Type, Optional[Callable]]) -> Callable[[Callable], Dispatcher]:
+    """Decorates a function and turns it into a :class:`Dispatcher`."""
+    return lambda fn: Dispatcher(fn, kernels)
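With this rewrite, @dispatch returns a Dispatcher instance instead of a wrapped function, which is what lets Transform._transform probe support with `input in self._DISPATCHER`. A toy sketch of defining one (double and _double_tensor are hypothetical, not part of the commit):

from typing import Any

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms.functional._utils import dispatch


def _double_tensor(input: torch.Tensor) -> torch.Tensor:
    # Hypothetical kernel for vanilla tensors.
    return input * 2


@dispatch(
    {
        torch.Tensor: _double_tensor,
        features.Image: None,  # None means the function body below handles it
    }
)
def double(input: Any, *args: Any, **kwargs: Any) -> Any:
    if isinstance(input, features.Image):
        return features.Image.new_like(input, input * 2)
    raise RuntimeError


assert torch.rand(2) in double  # Dispatcher.__contains__
assert "not dispatchable" not in double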