Unverified commit 52e6bd08 authored by Philip Meier, committed by GitHub

add prototype transforms that use the prototype dispatchers (#5418)

* add prototype transforms that use the prototype dispatchers

Conflicts:
	torchvision/prototype/transforms/__init__.py

* simplify

* add logger

* remove legacy classes

Conflicts:
	torchvision/prototype/transforms/_augment.py
	torchvision/prototype/transforms/_auto_augment.py
	torchvision/prototype/transforms/_geometry.py

* make get_params private

* remove randbool method

* remove AutoAugmentDispatcher

* add high level kernels for meta conversion

* remove transforms meta abstraction from auto augment transforms

* appease mypy

* add smoke tests for transforms

* remove Query object

* remove extra_repr helper

* fix tests

* appease mypy

* revert some changes on the kernel tests

* fix dispatcher annotations

* remove float cast for torch.rand

* add helper to query image

* fix imports

* address auto augment comments

* cleanup
parent 144f0980
......@@ -99,7 +99,6 @@ class TestCommon:
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
)
@pytest.mark.xfail
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
......
import itertools
import PIL.Image
import pytest
import torch
from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels
from torchvision.prototype import transforms, features
from torchvision.transforms.functional import to_pil_image
def make_vanilla_tensor_images(*args, **kwargs):
for image in make_images(*args, **kwargs):
if image.ndim > 3:
continue
yield image.data
def make_pil_images(*args, **kwargs):
for image in make_vanilla_tensor_images(*args, **kwargs):
yield to_pil_image(image)
def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
for bounding_box in make_bounding_boxes(*args, **kwargs):
yield bounding_box.data
INPUT_CREATIONS_FNS = {
features.Image: make_images,
features.BoundingBox: make_bounding_boxes,
features.OneHotLabel: make_one_hot_labels,
torch.Tensor: make_vanilla_tensor_images,
PIL.Image.Image: make_pil_images,
}
def parametrize(transforms_with_inputs):
return pytest.mark.parametrize(
("transform", "input"),
[
pytest.param(
transform,
input,
id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}",
)
for transform, inputs in transforms_with_inputs
for idx, input in enumerate(inputs)
],
)
def parametrize_from_transforms(*transforms):
transforms_with_inputs = []
for transform in transforms:
dispatcher = transform._DISPATCHER
if dispatcher is None:
continue
for type_ in dispatcher._kernels:
try:
inputs = INPUT_CREATIONS_FNS[type_]()
except KeyError:
continue
transforms_with_inputs.append((transform, inputs))
return parametrize(transforms_with_inputs)
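# Hedged illustration of the helper above (ids are indicative, not verbatim):
# every transform is paired with generated inputs for each type its dispatcher
# has a kernel for, producing parametrizations such as
#
#     HorizontalFlip-torchvision.prototype.features.Image-0
#     HorizontalFlip-PIL.Image.Image-0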
class TestSmoke:
@parametrize_from_transforms(
transforms.RandomErasing(),
transforms.HorizontalFlip(),
transforms.Resize([16, 16]),
transforms.CenterCrop([16, 16]),
transforms.ConvertImageDtype(),
)
def test_common(self, transform, input):
transform(input)
@parametrize(
[
(
transform,
[
dict(
image=features.Image.new_like(image, image.unsqueeze(0), dtype=torch.float),
one_hot_label=features.OneHotLabel.new_like(
one_hot_label, one_hot_label.unsqueeze(0), dtype=torch.float
),
)
for image, one_hot_label in itertools.product(make_images(), make_one_hot_labels())
],
)
for transform in [
transforms.RandomMixup(alpha=1.0),
transforms.RandomCutmix(alpha=1.0),
]
]
)
def test_mixup_cutmix(self, transform, input):
transform(input)
@parametrize(
[
(
transform,
itertools.chain.from_iterable(
fn(dtypes=[torch.uint8], extra_dims=[(4,)])
for fn in [
make_images,
make_vanilla_tensor_images,
make_pil_images,
]
),
)
for transform in (
transforms.RandAugment(),
transforms.TrivialAugmentWide(),
transforms.AutoAugment(),
)
]
)
def test_auto_augment(self, transform, input):
transform(input)
@parametrize(
[
(
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
itertools.chain.from_iterable(
fn(color_spaces=["rgb"], dtypes=[torch.float32])
for fn in [
make_images,
make_vanilla_tensor_images,
]
),
),
]
)
def test_normalize(self, transform, input):
transform(input)
@parametrize(
[
(
transforms.ConvertColorSpace("grayscale"),
itertools.chain(
make_images(),
make_vanilla_tensor_images(color_spaces=["rgb"]),
make_pil_images(color_spaces=["rgb"]),
),
)
]
)
def test_convert_color_space(self, transform, input):
transform(input)
@parametrize(
[
(
transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"),
itertools.chain(
make_bounding_boxes(),
make_vanilla_tensor_bounding_boxes(formats=["xywh"]),
),
)
]
)
def test_convert_bounding_box_format(self, transform, input):
transform(input)
@parametrize(
[
(
transforms.RandomResizedCrop([16, 16]),
itertools.chain(
make_images(extra_dims=[(4,)]),
make_vanilla_tensor_images(),
make_pil_images(),
),
)
]
)
def test_random_resized_crop(self, transform, input):
transform(input)
......@@ -5,6 +5,7 @@ import pytest
import torch.testing
import torchvision.prototype.transforms.kernels as K
from torch import jit
from torch.nn.functional import one_hot
from torchvision.prototype import features
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
......@@ -39,10 +40,10 @@ def make_images(
extra_dims=((4,), (2, 3)),
):
for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
yield make_image(size, color_space=color_space)
yield make_image(size, color_space=color_space, dtype=dtype)
for color_space, extra_dims_ in itertools.product(color_spaces, extra_dims):
yield make_image(color_space=color_space, extra_dims=extra_dims_)
for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims):
yield make_image(color_space=color_space, extra_dims=extra_dims_, dtype=dtype)
def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
......@@ -106,6 +107,27 @@ def make_bounding_boxes(
yield make_bounding_box(format=format, extra_dims=extra_dims_)
def make_label(size=(), *, categories=("category0", "category1")):
return features.Label(torch.randint(0, len(categories) if categories else 10, size), categories=categories)
def make_one_hot_label(*args, **kwargs):
label = make_label(*args, **kwargs)
return features.OneHotLabel(one_hot(label, num_classes=len(label.categories)), categories=label.categories)
def make_one_hot_labels(
*,
num_categories=(1, 2, 10),
extra_dims=((4,), (2, 3)),
):
for num_categories_ in num_categories:
yield make_one_hot_label(categories=[f"category{idx}" for idx in range(num_categories_)])
for extra_dims_ in extra_dims:
yield make_one_hot_label(extra_dims_)
class SampleInput:
def __init__(self, *args, **kwargs):
self.args = args
......
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip
from . import kernels # usort: skip
from . import functional # usort: skip
from .kernels import InterpolationMode # usort: skip
from ._transform import Transform # usort: skip
from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
from ._type_conversion import DecodeImage, LabelToOneHot
import math
import numbers
import warnings
from typing import Any, Dict, Tuple
import torch
from torchvision.prototype.transforms import Transform, functional as F
from ._utils import query_image
class RandomErasing(Transform):
_DISPATCHER = F.erase
def __init__(
self,
p: float = 0.5,
scale: Tuple[float, float] = (0.02, 0.33),
ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0,
):
super().__init__()
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
raise ValueError("If value is str, it should be 'random'")
if not isinstance(scale, (tuple, list)):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, (tuple, list)):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1:
raise ValueError("Scale should be between 0 and 1")
if p < 0 or p > 1:
raise ValueError("Random erasing probability should be between 0 and 1")
# TODO: deprecate p in favor of wrapping the transform in a RandomApply
self.p = p
self.scale = scale
self.ratio = ratio
self.value = value
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
img_h, img_w = F.get_image_size(image)
img_c = F.get_image_num_channels(image)
if isinstance(self.value, (int, float)):
value = [self.value]
elif isinstance(self.value, str):
value = None
elif isinstance(self.value, tuple):
value = list(self.value)
else:
value = self.value
if value is not None and len(value) not in (1, img_c):
raise ValueError(
f"If value is a sequence, it should have either a single value or {img_c} (number of input channels)"
)
area = img_h * img_w
log_ratio = torch.log(torch.tensor(self.ratio))
for _ in range(10):
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio)))
if not (h < img_h and w < img_w):
continue
if value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(value)[:, None, None]
i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
break
else:
i, j, h, w, v = 0, 0, img_h, img_w, image
return dict(zip("ijhwv", (i, j, h, w, v)))
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if torch.rand(1) >= self.p:
return input
return super()._transform(input, params)
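# A minimal usage sketch (assumes the prototype API in this file; shapes are
# illustrative). The erase parameters are sampled once per call in _get_params,
# then applied to every supported input through the F.erase dispatcher:
#
#     image = features.Image(torch.rand(3, 32, 32))
#     erased = RandomErasing(p=1.0, value=0.0)(image)
#     assert erased.shape == (3, 32, 32)  # same shape, one rectangle overwritten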
class RandomMixup(Transform):
_DISPATCHER = F.mixup
def __init__(self, *, alpha: float) -> None:
super().__init__()
self.alpha = alpha
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(())))
class RandomCutmix(Transform):
_DISPATCHER = F.cutmix
def __init__(self, *, alpha: float) -> None:
super().__init__()
self.alpha = alpha
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
def _get_params(self, sample: Any) -> Dict[str, Any]:
lam = float(self._dist.sample(()))
image = query_image(sample)
H, W = F.get_image_size(image)
r_x = torch.randint(W, ())
r_y = torch.randint(H, ())
r = 0.5 * math.sqrt(1.0 - lam)
r_w_half = int(r * W)
r_h_half = int(r * H)
x1 = int(torch.clamp(r_x - r_w_half, min=0))
y1 = int(torch.clamp(r_y - r_h_half, min=0))
x2 = int(torch.clamp(r_x + r_w_half, max=W))
y2 = int(torch.clamp(r_y + r_h_half, max=H))
box = (x1, y1, x2, y2)
lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
return dict(box=box, lam_adjusted=lam_adjusted)
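# Worked example of the box math above (hedged, with made-up numbers): for an
# image with H=100, W=200 and a sampled lam of 0.75, r = 0.5 * sqrt(0.25) = 0.25,
# giving r_w_half = 50 and r_h_half = 25 around the random center (r_x, r_y).
# If the box is not clamped at the borders, its area is 100 * 50 = 5000 pixels,
# so lam_adjusted = 1 - 5000 / (200 * 100) = 0.75, i.e. lam is only re-measured
# when clamping shrinks the box.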
import math
from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
from torchvision.prototype.utils._internal import apply_recursively
from ._utils import query_image
K = TypeVar("K")
V = TypeVar("V")
class _AutoAugmentBase(Transform):
def __init__(
self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None
) -> None:
super().__init__()
self.interpolation = interpolation
self.fill = fill
_DISPATCHER_MAP: Dict[str, Callable[[Any, float, InterpolationMode, Optional[List[float]]], Any]] = {
"Identity": lambda input, magnitude, interpolation, fill: input,
"ShearX": lambda input, magnitude, interpolation, fill: F.affine(
input,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(magnitude), 0.0],
interpolation=interpolation,
fill=fill,
),
"ShearY": lambda input, magnitude, interpolation, fill: F.affine(
input,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(magnitude)],
interpolation=interpolation,
fill=fill,
),
"TranslateX": lambda input, magnitude, interpolation, fill: F.affine(
input,
angle=0.0,
translate=[int(magnitude), 0],
scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation,
fill=fill,
),
"TranslateY": lambda input, magnitude, interpolation, fill: F.affine(
input,
angle=0.0,
translate=[0, int(magnitude)],
scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation,
fill=fill,
),
"Rotate": lambda input, magnitude, interpolation, fill: F.rotate(input, angle=magnitude),
"Brightness": lambda input, magnitude, interpolation, fill: F.adjust_brightness(
input, brightness_factor=1.0 + magnitude
),
"Color": lambda input, magnitude, interpolation, fill: F.adjust_saturation(
input, saturation_factor=1.0 + magnitude
),
"Contrast": lambda input, magnitude, interpolation, fill: F.adjust_contrast(
input, contrast_factor=1.0 + magnitude
),
"Sharpness": lambda input, magnitude, interpolation, fill: F.adjust_sharpness(
input, sharpness_factor=1.0 + magnitude
),
"Posterize": lambda input, magnitude, interpolation, fill: F.posterize(input, bits=int(magnitude)),
"Solarize": lambda input, magnitude, interpolation, fill: F.solarize(input, threshold=magnitude),
"AutoContrast": lambda input, magnitude, interpolation, fill: F.autocontrast(input),
"Equalize": lambda input, magnitude, interpolation, fill: F.equalize(input),
"Invert": lambda input, magnitude, interpolation, fill: F.invert(input),
}
def _is_supported(self, obj: Any) -> bool:
return type(obj) in {features.Image, torch.Tensor} or isinstance(obj, PIL.Image.Image)
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
num_channels = F.get_image_num_channels(image)
fill = self.fill
if isinstance(fill, (int, float)):
fill = [float(fill)] * num_channels
elif fill is not None:
fill = [float(f) for f in fill]
return dict(interpolation=self.interpolation, fill=fill)
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _apply_transform(self, sample: Any, params: Dict[str, Any], transform_id: str, magnitude: float) -> Any:
dispatcher = self._DISPATCHER_MAP[transform_id]
def transform(input: Any) -> Any:
if not self._is_supported(input):
return input
return dispatcher(input, magnitude, params["interpolation"], params["fill"])
return apply_recursively(transform, sample)
class AutoAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, image_size: None, False),
"Equalize": (lambda num_bins, image_size: None, False),
"Invert": (lambda num_bins, image_size: None, False),
}
def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.policy = policy
self._policies = self._get_policies(policy)
def _get_policies(
self, policy: AutoAugmentPolicy
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
if policy == AutoAugmentPolicy.IMAGENET:
return [
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
]
elif policy == AutoAugmentPolicy.CIFAR10:
return [
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
]
elif policy == AutoAugmentPolicy.SVHN:
return [
(("ShearX", 0.9, 4), ("Invert", 0.2, None)),
(("ShearY", 0.9, 8), ("Invert", 0.7, None)),
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
(("ShearY", 0.9, 8), ("Invert", 0.4, None)),
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
(("ShearY", 0.8, 8), ("Invert", 0.7, None)),
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
(("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
(("Invert", 0.6, None), ("Rotate", 0.8, 4)),
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
(("ShearX", 0.1, 6), ("Invert", 0.6, None)),
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
(("ShearY", 0.8, 4), ("Invert", 0.8, None)),
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
]
else:
raise ValueError(f"The provided policy {policy} is not recognized.")
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
params = params or self._get_params(sample)
image = query_image(sample)
image_size = F.get_image_size(image)
policy = self._policies[int(torch.randint(len(self._policies), ()))]
for transform_id, probability, magnitude_idx in policy:
if torch.rand(()) > probability:
continue
magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
magnitudes = magnitudes_fn(10, image_size)
if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
sample = self._apply_transform(sample, params, transform_id, magnitude)
return sample
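# Hedged usage sketch (assumes the prototype API above): one sub-policy is
# sampled per call, each op in it fires with its own probability, and the
# magnitude is read from the 10-bin augmentation space before being routed
# through _DISPATCHER_MAP.
#
#     image = features.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
#     augmented = AutoAugment(policy=AutoAugmentPolicy.IMAGENET)(image)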
class RandAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, image_size: None, False),
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, image_size: None, False),
"Equalize": (lambda num_bins, image_size: None, False),
}
def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.num_ops = num_ops
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
params = params or self._get_params(sample)
image = query_image(sample)
image_size = F.get_image_size(image)
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
sample = self._apply_transform(sample, params, transform_id, magnitude)
return sample
class TrivialAugmentWide(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, image_size: None, False),
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (
lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, image_size: None, False),
"Equalize": (lambda num_bins, image_size: None, False),
}
def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any):
super().__init__(**kwargs)
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
params = params or self._get_params(sample)
image = query_image(sample)
image_size = F.get_image_size(image)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
return self._apply_transform(sample, params, transform_id, magnitude)
from typing import Any, Optional, Dict
import torch
from ._transform import Transform
class Compose(Transform):
def __init__(self, *transforms: Transform):
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: # type: ignore[override]
sample = inputs if len(inputs) > 1 else inputs[0]
for transform in self.transforms:
sample = transform(sample)
return sample
class RandomApply(Transform):
def __init__(self, transform: Transform, *, p: float = 0.5) -> None:
super().__init__()
self.transform = transform
self.p = p
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
# apply the wrapped transform with probability p; otherwise pass the sample through
if float(torch.rand(())) >= self.p:
return sample
return self.transform(sample, params=params)
def extra_repr(self) -> str:
return f"p={self.p}"
class RandomChoice(Transform):
def __init__(self, *transforms: Transform):
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: # type: ignore[override]
idx = int(torch.randint(len(self.transforms), size=()))
transform = self.transforms[idx]
return transform(*inputs)
class RandomOrder(Transform):
def __init__(self, *transforms: Transform):
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: # type: ignore[override]
for idx in torch.randperm(len(self.transforms)):
transform = self.transforms[idx]
inputs = transform(*inputs)
return inputs
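# Hedged usage sketch of the containers above: Compose chains transforms over
# the whole sample, RandomApply gates a single transform with probability p,
# and RandomChoice / RandomOrder pick or permute their children per call.
#
#     from torchvision.prototype import features, transforms
#
#     pipeline = transforms.Compose(
#         transforms.RandomApply(transforms.HorizontalFlip(), p=0.5),
#         transforms.RandomChoice(transforms.Resize([16, 16]), transforms.CenterCrop([16, 16])),
#     )
#     out = pipeline(features.Image(torch.rand(3, 32, 32)))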
import math
import warnings
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
import torch
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
from ._utils import query_image
class HorizontalFlip(Transform):
_DISPATCHER = F.horizontal_flip
class Resize(Transform):
_DISPATCHER = F.resize
def __init__(
self,
size: Union[int, Sequence[int]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self.size = size
self.interpolation = interpolation
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(size=self.size, interpolation=self.interpolation)
class CenterCrop(Transform):
_DISPATCHER = F.center_crop
def __init__(self, output_size: List[int]):
super().__init__()
self.output_size = output_size
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(output_size=self.output_size)
class RandomResizedCrop(Transform):
_DISPATCHER = F.resized_crop
def __init__(
self,
size: Union[int, Sequence[int]],
scale: Tuple[float, float] = (0.08, 1.0),
ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if not isinstance(scale, Sequence):
raise TypeError("Scale should be a sequence")
scale = cast(Tuple[float, float], scale)
if not isinstance(ratio, Sequence):
raise TypeError("Ratio should be a sequence")
ratio = cast(Tuple[float, float], ratio)
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
self.scale = scale
self.ratio = ratio
self.interpolation = interpolation
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
height, width = F.get_image_size(image)
area = height * width
log_ratio = torch.log(torch.tensor(self.ratio))
for _ in range(10):
target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
i = torch.randint(0, height - h + 1, size=(1,)).item()
j = torch.randint(0, width - w + 1, size=(1,)).item()
break
else:
# Fallback to central crop
in_ratio = float(width) / float(height)
if in_ratio < min(self.ratio):
w = width
h = int(round(w / min(self.ratio)))
elif in_ratio > max(self.ratio):
h = height
w = int(round(h * max(self.ratio)))
else: # whole image
w = width
h = height
i = (height - h) // 2
j = (width - w) // 2
return dict(top=i, left=j, height=h, width=w, size=self.size, interpolation=self.interpolation)
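# Hedged usage sketch: the crop window is drawn once per call, so every
# image-like input in the sample is cropped at the same (top, left, height,
# width) before being resized to `size`.
#
#     crop = RandomResizedCrop([16, 16])
#     out = crop(features.Image(torch.rand(3, 32, 32)))
#     assert out.shape == (3, 16, 16)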
from typing import Union, Any, Dict, Optional
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from torchvision.transforms.functional import convert_image_dtype
class ConvertBoundingBoxFormat(Transform):
_DISPATCHER = F.convert_format
def __init__(
self,
format: Union[str, features.BoundingBoxFormat],
old_format: Optional[Union[str, features.BoundingBoxFormat]] = None,
) -> None:
super().__init__()
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
self.format = format
if isinstance(old_format, str):
old_format = features.BoundingBoxFormat[old_format]
self.old_format = old_format
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(format=self.format, old_format=self.old_format)
class ConvertImageDtype(Transform):
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, features.Image):
return input
output = convert_image_dtype(input, dtype=self.dtype)
return features.Image.new_like(input, output, dtype=self.dtype)
class ConvertColorSpace(Transform):
_DISPATCHER = F.convert_color_space
def __init__(self, color_space: Union[str, features.ColorSpace]) -> None:
super().__init__()
if isinstance(color_space, str):
color_space = features.ColorSpace[color_space]
self.color_space = color_space
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(color_space=self.color_space)
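# Hedged usage sketch (mirrors the smoke tests for these transforms; the
# constructor arguments are illustrative): only the metadata-bearing feature
# types are rewritten, everything else passes through untouched.
#
#     box = features.BoundingBox(
#         torch.tensor([2.0, 2.0, 5.0, 5.0]), format="xywh", image_size=(32, 32)
#     )
#     converted = ConvertBoundingBoxFormat("xyxy", old_format="xywh")(box)
#     # converted.format is now XYXY; plain tensors need old_format explicitly
#     # since they carry no metadata of their own.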
import functools
from typing import Any, List, Type, Callable, Dict
import torch
from torchvision.prototype.transforms import Transform, functional as F
class Identity(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
return input
class Lambda(Transform):
def __init__(self, fn: Callable[[Any], Any], *types: Type):
super().__init__()
self.fn = fn
self.types = types
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, self.types):
return input
return self.fn(input)
def extra_repr(self) -> str:
extras = []
name = getattr(self.fn, "__name__", None)
if name:
extras.append(name)
extras.append(f"types={[type.__name__ for type in self.types]}")
return ", ".join(extras)
class Normalize(Transform):
_DISPATCHER = F.normalize
def __init__(self, mean: List[float], std: List[float]):
super().__init__()
self.mean = mean
self.std = std
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(mean=self.mean, std=self.std)
class ToDtype(Lambda):
def __init__(self, dtype: torch.dtype, *types: Type) -> None:
self.dtype = dtype
super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types)
def extra_repr(self) -> str:
return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"])
import enum
import functools
from typing import Any, Dict, Optional
from torch import nn
from torchvision.prototype.utils._internal import apply_recursively
from torchvision.utils import _log_api_usage_once
from .functional._utils import Dispatcher
class Transform(nn.Module):
_DISPATCHER: Optional[Dispatcher] = None
def __init__(self) -> None:
super().__init__()
_log_api_usage_once(self)
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict()
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not self._DISPATCHER:
raise NotImplementedError()
if input not in self._DISPATCHER:
return input
return self._DISPATCHER(input, **params)
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
return apply_recursively(functools.partial(self._transform, params=params or self._get_params(sample)), sample)
def extra_repr(self) -> str:
extra = []
for name, value in self.__dict__.items():
if name.startswith("_") or name == "training":
continue
if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)):
continue
extra.append(f"{name}={value}")
return ", ".join(extra)
from typing import Any, Dict
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, kernels as K
class DecodeImage(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, features.EncodedImage):
return input
return features.Image(K.decode_image_with_pil(input))
class LabelToOneHot(Transform):
def __init__(self, num_categories: int = -1):
super().__init__()
self.num_categories = num_categories
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, features.Label):
return input
num_categories = self.num_categories
if num_categories == -1 and input.categories is not None:
num_categories = len(input.categories)
return features.OneHotLabel(
K.label_to_one_hot(input, num_categories=num_categories), categories=input.categories
)
def extra_repr(self) -> str:
if self.num_categories == -1:
return ""
return f"num_categories={self.num_categories}"
from typing import Any, Union, Optional
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.utils._internal import query_recursively
def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:
if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
return input
return None
try:
return next(query_recursively(fn, sample))
except StopIteration:
raise TypeError("No image was found in the sample")
......@@ -11,4 +11,5 @@ from ._color import (
invert,
)
from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate
from ._misc import normalize
from ._meta_conversion import convert_color_space, convert_format
from ._misc import normalize, get_image_size, get_image_num_channels
from typing import TypeVar, Any
from typing import Any
import torch
from torchvision.prototype import features
......@@ -7,8 +7,6 @@ from torchvision.transforms import functional as _F
from ._utils import dispatch
T = TypeVar("T", bound=features._Feature)
@dispatch(
{
......@@ -16,7 +14,7 @@ T = TypeVar("T", bound=features._Feature)
features.Image: K.erase_image,
}
)
def erase(input: T, *args: Any, **kwargs: Any) -> T:
def erase(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -27,18 +25,18 @@ def erase(input: T, *args: Any, **kwargs: Any) -> T:
features.OneHotLabel: K.mixup_one_hot_label,
}
)
def mixup(input: T, *args: Any, **kwargs: Any) -> T:
def mixup(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
features.Image: K.cutmix_image,
features.OneHotLabel: K.cutmix_one_hot_label,
features.Image: None,
features.OneHotLabel: None,
}
)
def cutmix(input: T, *args: Any, **kwargs: Any) -> T:
def cutmix(input: Any, *args: Any, **kwargs: Any) -> Any:
"""Perform the CutMix operation as introduced in the paper
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" <https://arxiv.org/abs/1905.04899>`_.
......@@ -54,4 +52,13 @@ def cutmix(input: T, *args: Any, **kwargs: Any) -> T:
Please refer to the kernel documentations for a detailed explanation of the functionality and parameters.
"""
...
if isinstance(input, features.Image):
kwargs.pop("lam_adjusted", None)
output = K.cutmix_image(input, **kwargs)
return features.Image.new_like(input, output)
elif isinstance(input, features.OneHotLabel):
kwargs.pop("box", None)
output = K.cutmix_one_hot_label(input, **kwargs)
return features.OneHotLabel.new_like(input, output)
raise RuntimeError
from typing import TypeVar, Any
from typing import Any
import PIL.Image
import torch
......@@ -8,8 +8,6 @@ from torchvision.transforms import functional as _F
from ._utils import dispatch
T = TypeVar("T", bound=features._Feature)
@dispatch(
{
......@@ -18,7 +16,7 @@ T = TypeVar("T", bound=features._Feature)
features.Image: K.adjust_brightness_image,
}
)
def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T:
def adjust_brightness(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -30,7 +28,7 @@ def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.adjust_saturation_image,
}
)
def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T:
def adjust_saturation(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -42,7 +40,7 @@ def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.adjust_contrast_image,
}
)
def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T:
def adjust_contrast(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -54,7 +52,7 @@ def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.adjust_sharpness_image,
}
)
def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T:
def adjust_sharpness(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -66,7 +64,7 @@ def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.posterize_image,
}
)
def posterize(input: T, *args: Any, **kwargs: Any) -> T:
def posterize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -78,7 +76,7 @@ def posterize(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.solarize_image,
}
)
def solarize(input: T, *args: Any, **kwargs: Any) -> T:
def solarize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -90,7 +88,7 @@ def solarize(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.autocontrast_image,
}
)
def autocontrast(input: T, *args: Any, **kwargs: Any) -> T:
def autocontrast(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -102,7 +100,7 @@ def autocontrast(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.equalize_image,
}
)
def equalize(input: T, *args: Any, **kwargs: Any) -> T:
def equalize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -114,7 +112,7 @@ def equalize(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.invert_image,
}
)
def invert(input: T, *args: Any, **kwargs: Any) -> T:
def invert(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -126,7 +124,7 @@ def invert(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.adjust_hue_image,
}
)
def adjust_hue(input: T, *args: Any, **kwargs: Any) -> T:
def adjust_hue(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -138,6 +136,6 @@ def adjust_hue(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.adjust_gamma_image,
}
)
def adjust_gamma(input: T, *args: Any, **kwargs: Any) -> T:
def adjust_gamma(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
from typing import TypeVar, Any, cast
from typing import Any
import PIL.Image
import torch
......@@ -8,8 +8,6 @@ from torchvision.transforms import functional as _F
from ._utils import dispatch
T = TypeVar("T", bound=features._Feature)
@dispatch(
{
......@@ -19,11 +17,11 @@ T = TypeVar("T", bound=features._Feature)
features.BoundingBox: None,
},
)
def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T:
def horizontal_flip(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
if isinstance(input, features.BoundingBox):
output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
return cast(T, features.BoundingBox.new_like(input, output))
return features.BoundingBox.new_like(input, output)
raise RuntimeError
......@@ -37,12 +35,12 @@ def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T:
features.BoundingBox: None,
}
)
def resize(input: T, *args: Any, **kwargs: Any) -> T:
def resize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
if isinstance(input, features.BoundingBox):
size = kwargs.pop("size")
output = K.resize_bounding_box(input, size=size, image_size=input.image_size)
return cast(T, features.BoundingBox.new_like(input, output, image_size=size))
return features.BoundingBox.new_like(input, output, image_size=size)
raise RuntimeError
......@@ -54,7 +52,7 @@ def resize(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.center_crop_image,
}
)
def center_crop(input: T, *args: Any, **kwargs: Any) -> T:
def center_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -66,7 +64,7 @@ def center_crop(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.resized_crop_image,
}
)
def resized_crop(input: T, *args: Any, **kwargs: Any) -> T:
def resized_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -78,7 +76,7 @@ def resized_crop(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.affine_image,
}
)
def affine(input: T, *args: Any, **kwargs: Any) -> T:
def affine(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -90,7 +88,7 @@ def affine(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.rotate_image,
}
)
def rotate(input: T, *args: Any, **kwargs: Any) -> T:
def rotate(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -102,7 +100,7 @@ def rotate(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.pad_image,
}
)
def pad(input: T, *args: Any, **kwargs: Any) -> T:
def pad(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -114,7 +112,7 @@ def pad(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.crop_image,
}
)
def crop(input: T, *args: Any, **kwargs: Any) -> T:
def crop(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -126,7 +124,7 @@ def crop(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.perspective_image,
}
)
def perspective(input: T, *args: Any, **kwargs: Any) -> T:
def perspective(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -138,7 +136,7 @@ def perspective(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.vertical_flip_image,
}
)
def vertical_flip(input: T, *args: Any, **kwargs: Any) -> T:
def vertical_flip(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -150,7 +148,7 @@ def vertical_flip(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.five_crop_image,
}
)
def five_crop(input: T, *args: Any, **kwargs: Any) -> T:
def five_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -162,6 +160,6 @@ def five_crop(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.ten_crop_image,
}
)
def ten_crop(input: T, *args: Any, **kwargs: Any) -> T:
def ten_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
from typing import Any
import PIL.Image
import torch
from torchvision.ops import box_convert
from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F
from ._utils import dispatch
@dispatch(
{
torch.Tensor: None,
features.BoundingBox: None,
}
)
def convert_format(input: Any, *args: Any, **kwargs: Any) -> Any:
format = kwargs["format"]
if type(input) is torch.Tensor:
old_format = kwargs.get("old_format")
if old_format is None:
raise TypeError("For vanilla tensors the `old_format` needs to be provided.")
return box_convert(input, in_fmt=old_format.name.lower(), out_fmt=format.name.lower())
elif isinstance(input, features.BoundingBox):
output = K.convert_bounding_box_format(input, old_format=input.format, new_format=format)
return features.BoundingBox.new_like(input, output, format=format)
raise RuntimeError
@dispatch(
{
torch.Tensor: None,
PIL.Image.Image: None,
features.Image: None,
}
)
def convert_color_space(input: Any, *args: Any, **kwargs: Any) -> Any:
color_space = kwargs["color_space"]
if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image):
if color_space != features.ColorSpace.GRAYSCALE:
raise ValueError("For vanilla tensors and PIL images only RGB to grayscale is supported")
return _F.rgb_to_grayscale(input)
elif isinstance(input, features.Image):
output = K.convert_color_space(input, old_color_space=input.color_space, new_color_space=color_space)
return features.Image.new_like(input, output, color_space=color_space)
raise RuntimeError
from typing import TypeVar, Any
from typing import Any
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F
from torchvision.transforms.functional_pil import (
get_image_size as _get_image_size_pil,
get_image_num_channels as _get_image_num_channels_pil,
)
from torchvision.transforms.functional_tensor import (
get_image_size as _get_image_size_tensor,
get_image_num_channels as _get_image_num_channels_tensor,
)
from ._utils import dispatch
T = TypeVar("T", bound=features._Feature)
@dispatch(
{
......@@ -17,7 +23,7 @@ T = TypeVar("T", bound=features._Feature)
features.Image: K.normalize_image,
}
)
def normalize(input: T, *args: Any, **kwargs: Any) -> T:
def normalize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
......@@ -29,6 +35,35 @@ def normalize(input: T, *args: Any, **kwargs: Any) -> T:
features.Image: K.gaussian_blur_image,
}
)
def ten_gaussian_blur(input: T, *args: Any, **kwargs: Any) -> T:
def gaussian_blur(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
# the legacy kernels return (width, height); reverse them so every branch of
# this dispatcher follows the (height, width) convention of features.Image
torch.Tensor: lambda input: _get_image_size_tensor(input)[::-1],
PIL.Image.Image: lambda input: _get_image_size_pil(input)[::-1],
features.Image: None,
features.BoundingBox: None,
}
)
def get_image_size(input: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(input, (features.Image, features.BoundingBox)):
return list(input.image_size)
raise RuntimeError
@dispatch(
{
torch.Tensor: _get_image_num_channels_tensor,
PIL.Image.Image: _get_image_num_channels_pil,
features.Image: None,
}
)
def get_image_num_channels(input: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(input, features.Image):
return input.num_channels
raise RuntimeError
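# Hedged note on conventions: features.Image.image_size is stored as
# (height, width), and the legacy tensor/PIL kernels are reversed above so
# that every branch of get_image_size agrees with the
# `img_h, img_w = F.get_image_size(image)` unpacking used by the transforms
# in this PR.
#
#     get_image_size(features.Image(torch.rand(3, 32, 48)))          # -> [32, 48]
#     get_image_num_channels(features.Image(torch.rand(3, 32, 48)))  # -> 3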
import functools
import inspect
from typing import Any, Optional, Callable, TypeVar, Dict
from typing import Any, Optional, Callable, TypeVar, Mapping, Type
import torch
import torch.overrides
......@@ -9,10 +8,10 @@ from torchvision.prototype import features
F = TypeVar("F", bound=features._Feature)
def dispatch(kernels: Dict[Any, Optional[Callable]]) -> Callable[[Callable[..., F]], Callable[..., F]]:
"""Decorates a function to automatically dispatch to registered kernels based on the call arguments.
class Dispatcher:
"""Wrap a function to automatically dispatch to registered kernels based on the call arguments.
The dispatch function should have this signature
The wrapped function should have this signature
.. code:: python
......@@ -34,7 +33,19 @@ def dispatch(kernels: Dict[Any, Optional[Callable]]) -> Callable[[Callable[...,
TypeError: If the decorated function is called with an input that cannot be dispatched.
"""
def check_kernel(kernel: Any) -> bool:
def __init__(self, fn: Callable, kernels: Mapping[Type, Optional[Callable]]):
self._fn = fn
for feature_type, kernel in kernels.items():
if not self._check_kernel(kernel):
raise TypeError(
f"Kernel for feature type {feature_type.__name__} is not callable with "
f"kernel(input, *args, **kwargs)."
)
self._kernels = kernels
def _check_kernel(self, kernel: Optional[Callable]) -> bool:
if kernel is None:
return True
......@@ -47,43 +58,45 @@ def dispatch(kernels: Dict[Any, Optional[Callable]]) -> Callable[[Callable[...,
return params[0].kind != inspect.Parameter.KEYWORD_ONLY
# removed: the previous decorator-based implementation
for feature_type, kernel in kernels.items():
if not check_kernel(kernel):
raise TypeError(
f"Kernel for feature type {feature_type.__name__} is not callable with kernel(input, *args, **kwargs)."
)
def outer_wrapper(dispatch_fn: Callable[..., F]) -> Callable[..., F]:
@functools.wraps(dispatch_fn)
def inner_wrapper(input: F, *args: Any, **kwargs: Any) -> F:
feature_type = type(input)
try:
kernel = kernels[feature_type]
except KeyError:
try:
feature_type, kernel = next(
(feature_type, kernel)
for feature_type, kernel in kernels.items()
if isinstance(input, feature_type)
)
except StopIteration:
raise TypeError(f"No support for {type(input).__name__}") from None
if kernel is None:
output = dispatch_fn(input, *args, **kwargs)
if output is None:
raise RuntimeError(
f"{dispatch_fn.__name__}() did not handle inputs of type {type(input).__name__} "
f"although it was configured to do so."
)
else:
output = kernel(input, *args, **kwargs)
if issubclass(feature_type, features._Feature) and type(output) is torch.Tensor:
output = feature_type.new_like(input, output)
return output
return inner_wrapper
return outer_wrapper
# added: kernel resolution now lives on the Dispatcher class
def _resolve(self, feature_type: Type) -> Optional[Callable]:
try:
return self._kernels[feature_type]
except KeyError:
try:
return next(
kernel
for registered_feature_type, kernel in self._kernels.items()
if issubclass(feature_type, registered_feature_type)
)
except StopIteration:
raise TypeError(f"No support for feature type {feature_type.__name__}") from None
def __contains__(self, obj: Any) -> bool:
try:
self._resolve(type(obj))
return True
except TypeError:
return False
def __call__(self, input: Any, *args: Any, **kwargs: Any) -> Any:
kernel = self._resolve(type(input))
if kernel is None:
output = self._fn(input, *args, **kwargs)
if output is None:
raise RuntimeError(
f"{self._fn.__name__}() did not handle inputs of type {type(input).__name__} "
f"although it was configured to do so."
)
else:
output = kernel(input, *args, **kwargs)
if isinstance(input, features._Feature) and type(output) is torch.Tensor:
output = type(input).new_like(input, output)
return output
def dispatch(kernels: Mapping[Type, Optional[Callable]]) -> Callable[[Callable], Dispatcher]:
"""Decorates a function and turns it into a :class:`Dispatcher`."""
return lambda fn: Dispatcher(fn, kernels)
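# Hedged usage sketch (kernel names are hypothetical stand-ins): a kernel of
# None means "handled by the wrapped function body"; any other kernel is called
# directly, and plain tensor outputs are re-wrapped into the input's feature
# type via new_like.
#
#     @dispatch(
#         {
#             torch.Tensor: _F.hflip,                  # legacy kernel, used as-is
#             features.Image: K.horizontal_flip_image,
#             features.BoundingBox: None,              # falls through to the body
#         }
#     )
#     def horizontal_flip(input: Any, *args: Any, **kwargs: Any) -> Any:
#         if isinstance(input, features.BoundingBox):
#             ...  # custom handling that needs the feature's metadata
#         raise RuntimeError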