Unverified Commit d32bc4ba authored by Philip Meier and committed by GitHub

Revamp prototype features and transforms (#5407)

* revamp prototype features (#5283)

* remove decoding from prototype datasets (#5287)

* remove decoder from prototype datasets

* remove unused imports

* cleanup

* fix readme

* use OneHotLabel in SEMEION

* improve voc implementation

* revert unrelated changes

* fix semeion mock data

* fix pcam

* readd functional transforms API to prototype (#5295)

* readd functional transforms

* cleanup

* add missing imports

* remove __torch_function__ dispatch

* readd repr

* readd empty line

* add test for scriptability

* remove function copy

* change import from functional tensor transforms to just functional

* fix import

* fix test

* fix prototype features and functional transforms after review (#5377)

* fix prototype functional transforms after review

* address features review

* make mypy more strict on prototype features

* make mypy more strict for prototype transforms

* fix annotation

* fix kernel tests

* add automatic feature type dispatch to functional transforms (#5323)

* add auto dispatch

* fix missing arguments error message

* remove pil kernel for erase

* automate feature specific parameter detection

* fix typos

* cleanup dispatcher call

* remove __torch_function__ from transform dispatch

* remove auto-generation

* revert unrelated changes

* remove implements decorator

* change register parameter order

* change order of transforms for readability

* add documentation for __torch_function__

* fix mypy

* inline check for support

* refactor kernel registering process

* refactor dispatch to be a regular decorator

* split kernels and dispatchers

* remove sentinels

* replace pass with ...

* appease mypy

* make single kernel dispatchers more concise

* make dispatcher signatures more generic

* make kernel checking more strict

* revert doc changes

* address Franciscos comments

* remove inplace

* rename kernel test module

* fix inplace

* remove special casing for pil and vanilla tensors

* address comments

* update docs

* cleanup features / transforms feature branch (#5406)

* mark candidates for removal

* align signature of resize_bounding_box with corresponding image kernel

* fix documentation of Feature

* remove interpolation mode and antialias option from resize_segmentation_mask

* remove or privatize functionality in features / datasets / transforms
parent f2f490b1
from ._transform import Transform
from ._container import Compose, RandomApply, RandomChoice, RandomOrder # usort: skip
from . import kernels # usort: skip
from . import functional # usort: skip
from .kernels import InterpolationMode # usort: skip
from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop
from ._misc import Identity, Normalize
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
from typing import Any, List

import torch
from torch import nn

from torchvision.prototype.transforms import Transform


class ContainerTransform(nn.Module):
    def supports(self, obj: Any) -> bool:
        raise NotImplementedError()

    def forward(self, *inputs: Any) -> Any:
        raise NotImplementedError()

    def _make_repr(self, lines: List[str]) -> str:
        extra_repr = self.extra_repr()
        if extra_repr:
            lines = [self.extra_repr(), *lines]
        head = f"{type(self).__name__}("
        tail = ")"
        body = [f" {line.rstrip()}" for line in lines]
        return "\n".join([head, *body, tail])


class WrapperTransform(ContainerTransform):
    def __init__(self, transform: Transform):
        super().__init__()
        self._transform = transform

    def supports(self, obj: Any) -> bool:
        return self._transform.supports(obj)

    def __repr__(self) -> str:
        return self._make_repr(repr(self._transform).splitlines())


class MultiTransform(ContainerTransform):
    def __init__(self, *transforms: Transform) -> None:
        super().__init__()
        self._transforms = transforms

    def supports(self, obj: Any) -> bool:
        return all(transform.supports(obj) for transform in self._transforms)

    def __repr__(self) -> str:
        lines = []
        for idx, transform in enumerate(self._transforms):
            partial_lines = repr(transform).splitlines()
            lines.append(f"({idx:d}): {partial_lines[0]}")
            lines.extend(partial_lines[1:])
        return self._make_repr(lines)


class Compose(MultiTransform):
    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self._transforms:
            sample = transform(sample)
        return sample


class RandomApply(WrapperTransform):
    def __init__(self, transform: Transform, *, p: float = 0.5) -> None:
        super().__init__(transform)
        self._p = p

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        if float(torch.rand(())) < self._p:
            return sample
        return self._transform(sample)

    def extra_repr(self) -> str:
        return f"p={self._p}"


class RandomChoice(MultiTransform):
    def forward(self, *inputs: Any) -> Any:
        idx = int(torch.randint(len(self._transforms), size=()))
        transform = self._transforms[idx]
        return transform(*inputs)


class RandomOrder(MultiTransform):
    def forward(self, *inputs: Any) -> Any:
        for idx in torch.randperm(len(self._transforms)):
            transform = self._transforms[idx]
            inputs = transform(*inputs)
        return inputs
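A minimal usage sketch of the container transforms above, assuming the prototype Image feature and the Resize / HorizontalFlip transforms defined further down in this diff; the pipeline and sample names are illustrative only:

import torch
from torchvision.prototype.features import Image
from torchvision.prototype.transforms import Compose, HorizontalFlip, RandomApply, Resize

# Chain transforms; each one receives the full sample and dispatches per feature type.
pipeline = Compose(
    Resize((32, 32)),
    RandomApply(HorizontalFlip(), p=0.5),
)
sample = dict(image=Image(torch.rand(3, 64, 64)))
out = pipeline(sample)  # same nested structure, with supported features transformed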
from typing import Any, Dict, Tuple, Union

import torch
from torch.nn.functional import interpolate

from torchvision.prototype.datasets.utils import SampleQuery
from torchvision.prototype.features import BoundingBox, Image, Label
from torchvision.prototype.transforms import Transform


class HorizontalFlip(Transform):
    NO_OP_FEATURE_TYPES = {Label}

    @staticmethod
    def image(input: Image) -> Image:
        return Image(input.flip((-1,)), like=input)

    @staticmethod
    def bounding_box(input: BoundingBox) -> BoundingBox:
        x, y, w, h = input.convert("xywh").to_parts()
        x = input.image_size[1] - (x + w)
        return BoundingBox.from_parts(x, y, w, h, like=input, format="xywh").convert(input.format)


class Resize(Transform):
    NO_OP_FEATURE_TYPES = {Label}

    def __init__(
        self,
        size: Union[int, Tuple[int, int]],
        *,
        interpolation_mode: str = "nearest",
    ) -> None:
        super().__init__()
        self.size = (size, size) if isinstance(size, int) else size
        self.interpolation_mode = interpolation_mode

    def get_params(self, sample: Any) -> Dict[str, Any]:
        return dict(size=self.size, interpolation_mode=self.interpolation_mode)

    @staticmethod
    def image(input: Image, *, size: Tuple[int, int], interpolation_mode: str = "nearest") -> Image:
        return Image(interpolate(input.unsqueeze(0), size, mode=interpolation_mode).squeeze(0), like=input)

    @staticmethod
    def bounding_box(input: BoundingBox, *, size: Tuple[int, int], **_: Any) -> BoundingBox:
        old_height, old_width = input.image_size
        new_height, new_width = size
        height_scale = new_height / old_height
        width_scale = new_width / old_width
        old_x1, old_y1, old_x2, old_y2 = input.convert("xyxy").to_parts()
        new_x1 = old_x1 * width_scale
        new_y1 = old_y1 * height_scale
        new_x2 = old_x2 * width_scale
        new_y2 = old_y2 * height_scale
        return BoundingBox.from_parts(
            new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=size
        ).convert(input.format)

    def extra_repr(self) -> str:
        extra_repr = f"size={self.size}"
        if self.interpolation_mode != "bilinear":
            extra_repr += f", interpolation_mode={self.interpolation_mode}"
        return extra_repr


class RandomResize(Transform, wraps=Resize):
    def __init__(self, min_size: Union[int, Tuple[int, int]], max_size: Union[int, Tuple[int, int]]) -> None:
        super().__init__()
        self.min_size = (min_size, min_size) if isinstance(min_size, int) else min_size
        self.max_size = (max_size, max_size) if isinstance(max_size, int) else max_size

    def get_params(self, sample: Any) -> Dict[str, Any]:
        min_height, min_width = self.min_size
        max_height, max_width = self.max_size
        height = int(torch.randint(min_height, max_height + 1, size=()))
        width = int(torch.randint(min_width, max_width + 1, size=()))
        return dict(size=(height, width))

    def extra_repr(self) -> str:
        return f"min_size={self.min_size}, max_size={self.max_size}"


class Crop(Transform):
    NO_OP_FEATURE_TYPES = {BoundingBox, Label}

    def __init__(self, crop_box: BoundingBox) -> None:
        super().__init__()
        self.crop_box = crop_box.convert("xyxy")

    def get_params(self, sample: Any) -> Dict[str, Any]:
        return dict(crop_box=self.crop_box)

    @staticmethod
    def image(input: Image, *, crop_box: BoundingBox) -> Image:
        # FIXME: pad input in case it is smaller than crop_box
        x1, y1, x2, y2 = crop_box.convert("xyxy").to_parts()
        return Image(input[..., y1 : y2 + 1, x1 : x2 + 1], like=input)  # type: ignore[misc]


class CenterCrop(Transform, wraps=Crop):
    def __init__(self, crop_size: Union[int, Tuple[int, int]]) -> None:
        super().__init__()
        self.crop_size = (crop_size, crop_size) if isinstance(crop_size, int) else crop_size

    def get_params(self, sample: Any) -> Dict[str, Any]:
        image_size = SampleQuery(sample).image_size()
        image_height, image_width = image_size
        cx = image_width // 2
        cy = image_height // 2
        h, w = self.crop_size
        crop_box = BoundingBox.from_parts(cx, cy, w, h, image_size=image_size, format="cxcywh")
        return dict(crop_box=crop_box)

    def extra_repr(self) -> str:
        return f"crop_size={self.crop_size}"


class RandomCrop(Transform, wraps=Crop):
    def __init__(self, crop_size: Union[int, Tuple[int, int]]) -> None:
        super().__init__()
        self.crop_size = (crop_size, crop_size) if isinstance(crop_size, int) else crop_size

    def get_params(self, sample: Any) -> Dict[str, Any]:
        image_size = SampleQuery(sample).image_size()
        image_height, image_width = image_size
        crop_height, crop_width = self.crop_size
        x = torch.randint(0, image_width - crop_width + 1, size=()) if crop_width < image_width else 0
        y = torch.randint(0, image_height - crop_height + 1, size=()) if crop_height < image_height else 0
        crop_box = BoundingBox.from_parts(x, y, crop_width, crop_height, image_size=image_size, format="xywh")
        return dict(crop_box=crop_box)

    def extra_repr(self) -> str:
        return f"crop_size={self.crop_size}"
from typing import Any, Dict, Sequence

import torch

from torchvision.prototype.features import Image, BoundingBox, Label
from torchvision.prototype.transforms import Transform


class Identity(Transform):
    """Identity transform that supports all built-in :class:`~torchvision.prototype.features.Feature`'s."""

    def __init__(self):
        super().__init__()
        for feature_type in self._BUILTIN_FEATURE_TYPES:
            self.register_feature_transform(feature_type, lambda input, **params: input)


class Normalize(Transform):
    NO_OP_FEATURE_TYPES = {BoundingBox, Label}

    def __init__(self, mean: Sequence[float], std: Sequence[float]):
        super().__init__()
        self.mean = mean
        self.std = std

    def get_params(self, sample: Any) -> Dict[str, Any]:
        return dict(mean=self.mean, std=self.std)

    @staticmethod
    def _channel_stats_to_tensor(stats: Sequence[float], *, like: torch.Tensor) -> torch.Tensor:
        return torch.as_tensor(stats, device=like.device, dtype=like.dtype).view(-1, 1, 1)

    @staticmethod
    def image(input: Image, *, mean: Sequence[float], std: Sequence[float]) -> Image:
        mean_t = Normalize._channel_stats_to_tensor(mean, like=input)
        std_t = Normalize._channel_stats_to_tensor(std, like=input)
        return Image((input - mean_t) / std_t, like=input)

    def extra_repr(self) -> str:
        return f"mean={tuple(self.mean)}, std={tuple(self.std)}"
import collections.abc
import inspect
import re
from typing import Any, Callable, Dict, Optional, Type, Union, cast, Set, Collection
import torch
from torch import nn
from torchvision.prototype import features
from torchvision.prototype.utils._internal import add_suggestion
class Transform(nn.Module):
    """Base class for transforms.

    A transform operates on a full sample at once, which might be a nested container of elements to transform. The
    non-container elements of the sample will be dispatched to feature transforms based on their type in case it is
    supported by the transform. Each transform needs to define at least one feature transform, which is canonically
    done as a static method:

    .. code-block::

        class ImageIdentity(Transform):
            @staticmethod
            def image(input):
                return input

    To achieve correct results for a complete sample, each transform should implement feature transforms for every
    :class:`Feature` it can handle:

    .. code-block::

        class Identity(Transform):
            @staticmethod
            def image(input):
                return input

            @staticmethod
            def bounding_box(input):
                return input

            ...

    If the snake-case name of a static method matches the camel-case name of a :class:`Feature`, the feature
    transform is auto-registered. Supported pairs are:

    +----------------+----------------+
    | method name    | `Feature`      |
    +================+================+
    | `image`        | `Image`        |
    +----------------+----------------+
    | `bounding_box` | `BoundingBox`  |
    +----------------+----------------+
    | `label`        | `Label`        |
    +----------------+----------------+

    If you don't want to stick to this scheme, you can disable the auto-registration and perform it manually:

    .. code-block::

        def my_image_transform(input):
            ...

        class MyTransform(Transform, auto_register=False):
            def __init__(self):
                super().__init__()
                self.register_feature_transform(Image, my_image_transform)
                self.register_feature_transform(BoundingBox, self.my_bounding_box_transform)

            @staticmethod
            def my_bounding_box_transform(input):
                ...

    In any case, the registration will assert that the feature transform can be invoked with
    ``feature_transform(input, **params)``.

    .. warning::

        Feature transforms are **registered on the class and not on the instance**. This means you cannot have two
        instances of the same :class:`Transform` with different feature transforms.

    If the feature transforms need additional parameters, you need to overwrite the :meth:`~Transform.get_params`
    method. It needs to return the parameter dictionary that will be unpacked and its contents passed to each
    feature transform:

    .. code-block::

        class Rotate(Transform):
            def __init__(self, degrees):
                super().__init__()
                self.degrees = degrees

            def get_params(self, sample):
                return dict(degrees=self.degrees)

            @staticmethod
            def image(input, *, degrees):
                ...

    The :meth:`~Transform.get_params` method will be invoked once per sample. Thus, in case of randomly sampled
    parameters they will be the same for all features of the whole sample.

    .. code-block::

        class RandomRotate(Transform):
            def __init__(self, range):
                super().__init__()
                self._dist = torch.distributions.Uniform(range)

            def get_params(self, sample):
                return dict(degrees=self._dist.sample().item())

            @staticmethod
            def image(input, *, degrees):
                ...

    In case the sampling depends on one or more features at runtime, the complete ``sample`` gets passed to the
    :meth:`Transform.get_params` method. Derivative transforms that only change the parameter sampling, but whose
    feature transformations are identical, can simply wrap the transform they dispatch to:

    .. code-block::

        class RandomRotate(Transform, wraps=Rotate):
            def get_params(self, sample):
                return dict(degrees=float(torch.rand(())) * 30.0)

    To transform a sample, you simply call an instance of the transform with it:

    .. code-block::

        transform = MyTransform()
        sample = dict(input=Image(torch.tensor(...)), target=BoundingBox(torch.tensor(...)), ...)
        transformed_sample = transform(sample)

    .. note::

        To use a :class:`Transform` with a dataset, simply use it as a map:

        .. code-block::

            torchvision.datasets.load(...).map(MyTransform())
    """
    _BUILTIN_FEATURE_TYPES = (
        features.BoundingBox,
        features.Image,
        features.Label,
    )
    _FEATURE_NAME_MAP = {
        "_".join([part.lower() for part in re.findall("[A-Z][^A-Z]*", feature_type.__name__)]): feature_type
        for feature_type in _BUILTIN_FEATURE_TYPES
    }
    _feature_transforms: Dict[Type[features.Feature], Callable]

    NO_OP_FEATURE_TYPES: Collection[Type[features.Feature]] = ()

    def __init_subclass__(
        cls, *, wraps: Optional[Type["Transform"]] = None, auto_register: bool = True, verbose: bool = False
    ):
        cls._feature_transforms = {} if wraps is None else wraps._feature_transforms.copy()
        if wraps:
            cls.NO_OP_FEATURE_TYPES = wraps.NO_OP_FEATURE_TYPES
        if auto_register:
            cls._auto_register(verbose=verbose)

    @staticmethod
    def _has_allowed_signature(feature_transform: Callable) -> bool:
        """Checks if ``feature_transform`` can be invoked with ``feature_transform(input, **params)``"""
        parameters = tuple(inspect.signature(feature_transform).parameters.values())
        if not parameters:
            return False
        elif len(parameters) == 1:
            return parameters[0].kind != inspect.Parameter.KEYWORD_ONLY
        else:
            return parameters[1].kind != inspect.Parameter.POSITIONAL_ONLY

    @classmethod
    def register_feature_transform(cls, feature_type: Type[features.Feature], transform: Callable) -> None:
        """Registers a transform for a given feature on the class.

        If a transform object is called or :meth:`Transform.apply` is invoked, inputs are dispatched to the
        registered transforms based on their type.

        Args:
            feature_type: Feature type the transformation is registered for.
            transform: Feature transformation.

        Raises:
            TypeError: If ``transform`` cannot be invoked with ``transform(input, **params)``.
        """
        if not cls._has_allowed_signature(transform):
            raise TypeError("Feature transform cannot be invoked with transform(input, **params)")
        cls._feature_transforms[feature_type] = transform

    @classmethod
    def _auto_register(cls, *, verbose: bool = False) -> None:
        """Auto-registers methods on the class as feature transforms if they meet the following criteria:

        1. They are static.
        2. They can be invoked with `cls.feature_transform(input, **params)`.
        3. They are public.
        4. Their snake-case name matches the name of a builtin feature, e.g. 'bounding_box' for 'BoundingBox'.

        The name from 4. determines for which feature the method is registered.

        .. note::

            The ``auto_register`` and ``verbose`` flags need to be passed as keyword arguments to the class:

            .. code-block::

                class MyTransform(Transform, auto_register=True, verbose=True):
                    ...

        Args:
            verbose: If ``True``, prints to STDOUT which methods were registered or why a method was not registered.
        """
        for name, value in inspect.getmembers(cls):
            # check if attribute is a static method and was defined in the subclass
            # TODO: this needs to be revisited to allow subclassing of custom transforms
            if not (name in cls.__dict__ and inspect.isfunction(value)):
                continue

            not_registered_prefix = f"{cls.__name__}.{name}() was not registered as feature transform, because"

            if not cls._has_allowed_signature(value):
                if verbose:
                    print(f"{not_registered_prefix} it cannot be invoked with {name}(input, **params).")
                continue

            if name.startswith("_"):
                if verbose:
                    print(f"{not_registered_prefix} it is private.")
                continue

            try:
                feature_type = cls._FEATURE_NAME_MAP[name]
            except KeyError:
                if verbose:
                    print(
                        add_suggestion(
                            f"{not_registered_prefix} its name doesn't match any known feature type.",
                            word=name,
                            possibilities=cls._FEATURE_NAME_MAP.keys(),
                            close_match_hint=lambda close_match: (
                                f"Did you mean to name it '{close_match}' "
                                f"to be registered for type '{cls._FEATURE_NAME_MAP[close_match]}'?"
                            ),
                        )
                    )
                continue

            cls.register_feature_transform(feature_type, value)
            if verbose:
                print(
                    f"{cls.__name__}.{name}() was registered as feature transform for type '{feature_type.__name__}'."
                )

    @classmethod
    def from_callable(
        cls,
        feature_transform: Union[Callable, Dict[Type[features.Feature], Callable]],
        *,
        name: str = "FromCallable",
        get_params: Optional[Union[Dict[str, Any], Callable[[Any], Dict[str, Any]]]] = None,
    ) -> "Transform":
        """Creates a new transform from a callable.

        Args:
            feature_transform: Feature transform that will be registered to handle :class:`Image`'s. Can be passed as
                a dictionary, in which case each key-value pair needs to consist of a ``Feature`` type and the
                corresponding transform.
            name: Name of the transform.
            get_params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``.
                Can be passed as a callable, in which case it will be called with the transform instance (``self``)
                and the input of the transform.

        Raises:
            TypeError: If ``feature_transform`` cannot be invoked with ``feature_transform(input, **params)``.
        """
        if get_params is None:
            get_params = dict()
        attributes = dict(
            get_params=get_params if callable(get_params) else lambda self, sample: get_params,  # type: ignore[misc]
        )
        transform_cls = cast(Type[Transform], type(name, (cls,), attributes))

        if callable(feature_transform):
            feature_transform = {features.Image: feature_transform}
        for feature_type, transform in feature_transform.items():
            transform_cls.register_feature_transform(feature_type, transform)

        return transform_cls()

    @classmethod
    def supported_feature_types(cls) -> Set[Type[features.Feature]]:
        return set(cls._feature_transforms.keys())

    @classmethod
    def supports(cls, obj: Any) -> bool:
        """Checks if an object or type is supported.

        Args:
            obj: Object or type.
        """
        # TODO: should this handle containers?
        feature_type = obj if isinstance(obj, type) else type(obj)
        return feature_type is torch.Tensor or feature_type in cls.supported_feature_types()

    @classmethod
    def transform(cls, input: Union[torch.Tensor, features.Feature], **params: Any) -> torch.Tensor:
        """Applies the registered feature transform to the input based on its type.

        This can be used as a feature-type-generic functional interface:

        .. code-block::

            transform = Rotate.transform
            transformed_image = transform(Image(torch.tensor(...)), degrees=30.0)
            transformed_bbox = transform(BoundingBox(torch.tensor(...)), degrees=-10.0)

        Args:
            input: ``input`` in ``feature_transform(input, **params)``.
            **params: Parameter dictionary ``params`` in ``feature_transform(input, **params)``.

        Returns:
            Transformed input.
        """
        feature_type = type(input)
        if not cls.supports(feature_type):
            raise TypeError(f"{cls.__name__}() is not able to handle inputs of type {feature_type}.")

        if feature_type is torch.Tensor:
            # To keep BC, we treat all regular torch.Tensor's as images
            feature_type = features.Image
            input = feature_type(input)
        feature_type = cast(Type[features.Feature], feature_type)

        feature_transform = cls._feature_transforms[feature_type]
        output = feature_transform(input, **params)

        if type(output) is torch.Tensor:
            output = feature_type(output, like=input)
        return output

    def _transform_recursively(self, sample: Any, *, params: Dict[str, Any]) -> Any:
        """Recurses through a sample and invokes :meth:`Transform.transform` on non-container elements.

        If an element is not supported by the transform, it is returned untransformed.

        Args:
            sample: Sample.
            params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``.
        """
        # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
        # "a" == "a"[0][0]...
        if isinstance(sample, collections.abc.Sequence) and not isinstance(sample, str):
            return [self._transform_recursively(item, params=params) for item in sample]
        elif isinstance(sample, collections.abc.Mapping):
            return {name: self._transform_recursively(item, params=params) for name, item in sample.items()}
        else:
            feature_type = type(sample)
            if not self.supports(feature_type):
                if (
                    not issubclass(feature_type, features.Feature)
                    # issubclass is not a strict check, but also allows the type checked against. Thus, we need to
                    # check it separately
                    or feature_type is features.Feature
                    or feature_type in self.NO_OP_FEATURE_TYPES
                ):
                    return sample

                raise TypeError(
                    f"{type(self).__name__}() is not able to handle inputs of type {feature_type}. "
                    f"If you want it to be a no-op, add the feature type to {type(self).__name__}.NO_OP_FEATURE_TYPES."
                )

            return self.transform(cast(Union[torch.Tensor, features.Feature], sample), **params)

    def get_params(self, sample: Any) -> Dict[str, Any]:
        """Returns the parameter dictionary used to transform the current sample.

        .. note::

            Since ``sample`` might be a nested container, it is recommended to use the
            :class:`torchvision.datasets.utils.Query` class if you need to extract information from it.

        Args:
            sample: Current sample.

        Returns:
            Parameter dictionary ``params`` in ``feature_transform(input, **params)``.
        """
        return dict()

    def forward(
        self,
        *inputs: Any,
        params: Optional[Dict[str, Any]] = None,
    ) -> Any:
        if not self._feature_transforms:
            raise RuntimeError(f"{type(self).__name__}() has no registered feature transform.")

        sample = inputs if len(inputs) > 1 else inputs[0]
        if params is None:
            params = self.get_params(sample)
        return self._transform_recursively(sample, params=params)
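To make the registration scheme above concrete, here is a minimal sketch of a custom transform written against this class-based API; Brighten and its factor parameter are hypothetical names used only for illustration:

import torch
from torchvision.prototype.features import Image
from torchvision.prototype.transforms import Transform

class Brighten(Transform):
    def __init__(self, factor: float) -> None:
        super().__init__()
        self.factor = factor

    def get_params(self, sample):
        # Called once per sample; the dict is unpacked into every feature transform.
        return dict(factor=self.factor)

    @staticmethod
    def image(input: Image, *, factor: float) -> Image:
        # Auto-registered for Image because the method name matches the feature type.
        return Image((input * factor).clamp(0, 1), like=input)

sample = dict(image=Image(torch.rand(3, 8, 8)), label=3)
out = Brighten(1.2)(sample)  # the image is scaled, the plain int label passes through untouched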
from ._augment import erase, mixup, cutmix
from ._color import (
    adjust_brightness,
    adjust_contrast,
    adjust_saturation,
    adjust_sharpness,
    posterize,
    solarize,
    autocontrast,
    equalize,
    invert,
)
from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate
from ._misc import normalize
from typing import TypeVar, Any

import torch

from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F

from ._utils import dispatch

T = TypeVar("T", bound=features._Feature)


@dispatch(
    {
        torch.Tensor: _F.erase,
        features.Image: K.erase_image,
    }
)
def erase(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        features.Image: K.mixup_image,
        features.OneHotLabel: K.mixup_one_hot_label,
    }
)
def mixup(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        features.Image: K.cutmix_image,
        features.OneHotLabel: K.cutmix_one_hot_label,
    }
)
def cutmix(input: T, *args: Any, **kwargs: Any) -> T:
    """Perform the CutMix operation as introduced in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" <https://arxiv.org/abs/1905.04899>`_.

    Dispatch to the corresponding kernels happens according to this table:

    .. table::
        :widths: 30 70

        ====================================================  ================================================================
        :class:`~torchvision.prototype.features.Image`        :func:`~torch.prototype.transforms.kernels.cutmix_image`
        :class:`~torchvision.prototype.features.OneHotLabel`  :func:`~torch.prototype.transforms.kernels.cutmix_one_hot_label`
        ====================================================  ================================================================

    Please refer to the kernel documentation for a detailed explanation of the functionality and parameters.
    """
    ...
from typing import TypeVar, Any

import PIL.Image
import torch

from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F

from ._utils import dispatch

T = TypeVar("T", bound=features._Feature)


@dispatch(
    {
        torch.Tensor: _F.adjust_brightness,
        PIL.Image.Image: _F.adjust_brightness,
        features.Image: K.adjust_brightness_image,
    }
)
def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        torch.Tensor: _F.adjust_saturation,
        PIL.Image.Image: _F.adjust_saturation,
        features.Image: K.adjust_saturation_image,
    }
)
def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        torch.Tensor: _F.adjust_contrast,
        PIL.Image.Image: _F.adjust_contrast,
        features.Image: K.adjust_contrast_image,
    }
)
def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        torch.Tensor: _F.adjust_sharpness,
        PIL.Image.Image: _F.adjust_sharpness,
        features.Image: K.adjust_sharpness_image,
    }
)
def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        torch.Tensor: _F.posterize,
        PIL.Image.Image: _F.posterize,
        features.Image: K.posterize_image,
    }
)
def posterize(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        torch.Tensor: _F.solarize,
        PIL.Image.Image: _F.solarize,
        features.Image: K.solarize_image,
    }
)
def solarize(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        torch.Tensor: _F.autocontrast,
        PIL.Image.Image: _F.autocontrast,
        features.Image: K.autocontrast_image,
    }
)
def autocontrast(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        torch.Tensor: _F.equalize,
        PIL.Image.Image: _F.equalize,
        features.Image: K.equalize_image,
    }
)
def equalize(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        torch.Tensor: _F.invert,
        PIL.Image.Image: _F.invert,
        features.Image: K.invert_image,
    }
)
def invert(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...
from typing import TypeVar, Any, cast

import PIL.Image
import torch

from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F

from ._utils import dispatch

T = TypeVar("T", bound=features._Feature)


@dispatch(
    {
        torch.Tensor: _F.hflip,
        PIL.Image.Image: _F.hflip,
        features.Image: K.horizontal_flip_image,
        features.BoundingBox: None,
    },
)
def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    if isinstance(input, features.BoundingBox):
        output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
        return cast(T, features.BoundingBox.new_like(input, output))

    raise RuntimeError


@dispatch(
    {
        torch.Tensor: _F.resize,
        PIL.Image.Image: _F.resize,
        features.Image: K.resize_image,
        features.SegmentationMask: K.resize_segmentation_mask,
        features.BoundingBox: None,
    }
)
def resize(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    if isinstance(input, features.BoundingBox):
        size = kwargs.pop("size")
        output = K.resize_bounding_box(input, size=size, image_size=input.image_size)
        return cast(T, features.BoundingBox.new_like(input, output, image_size=size))

    raise RuntimeError


@dispatch(
    {
        torch.Tensor: _F.center_crop,
        PIL.Image.Image: _F.center_crop,
        features.Image: K.center_crop_image,
    }
)
def center_crop(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        torch.Tensor: _F.resized_crop,
        PIL.Image.Image: _F.resized_crop,
        features.Image: K.resized_crop_image,
    }
)
def resized_crop(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        torch.Tensor: _F.affine,
        PIL.Image.Image: _F.affine,
        features.Image: K.affine_image,
    }
)
def affine(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...


@dispatch(
    {
        torch.Tensor: _F.rotate,
        PIL.Image.Image: _F.rotate,
        features.Image: K.rotate_image,
    }
)
def rotate(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...
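The BoundingBox entries above map to None, so the dispatchers themselves handle boxes by calling the bounding box kernels with the metadata stored on the feature. A minimal sketch, assuming the BoundingBox feature accepts format and image_size at construction as used elsewhere in this diff:

import torch
from torchvision.prototype import features

bbox = features.BoundingBox(
    torch.tensor([10.0, 20.0, 30.0, 40.0]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=(100, 100),
)
# Resolves to the `None` entry, so the body of `resize` above runs the bounding box kernel
# and re-wraps the result with the new image size via `new_like`.
resized = resize(bbox, size=[200, 50])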
from typing import TypeVar, Any

import torch

from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F

from ._utils import dispatch

T = TypeVar("T", bound=features._Feature)


@dispatch(
    {
        torch.Tensor: _F.normalize,
        features.Image: K.normalize_image,
    }
)
def normalize(input: T, *args: Any, **kwargs: Any) -> T:
    """ADDME"""
    ...
import functools
import inspect
from typing import Any, Optional, Callable, TypeVar, Dict

import torch
import torch.overrides

from torchvision.prototype import features

F = TypeVar("F", bound=features._Feature)


def dispatch(kernels: Dict[Any, Optional[Callable]]) -> Callable[[Callable[..., F]], Callable[..., F]]:
    """Decorates a function to automatically dispatch to registered kernels based on the call arguments.

    The dispatch function should have this signature

    .. code:: python

        @dispatch(
            ...
        )
        def dispatch_fn(input, *args, **kwargs):
            ...

    where ``input`` is used to determine which kernel to dispatch to.

    Args:
        kernels: Dictionary mapping each type to the kernel to call for it. The resolution order checks for an
            exact type match first and, if none is found, falls back to checking for subclasses. If a value is
            ``None``, the decorated function is called.

    Raises:
        TypeError: If any value in ``kernels`` is not callable with ``kernel(input, *args, **kwargs)``.
        TypeError: If the decorated function is called with an input that cannot be dispatched.
    """

    def check_kernel(kernel: Any) -> bool:
        if kernel is None:
            return True

        if not callable(kernel):
            return False

        params = list(inspect.signature(kernel).parameters.values())
        if not params:
            return False

        return params[0].kind != inspect.Parameter.KEYWORD_ONLY

    for feature_type, kernel in kernels.items():
        if not check_kernel(kernel):
            raise TypeError(
                f"Kernel for feature type {feature_type.__name__} is not callable with kernel(input, *args, **kwargs)."
            )

    def outer_wrapper(dispatch_fn: Callable[..., F]) -> Callable[..., F]:
        @functools.wraps(dispatch_fn)
        def inner_wrapper(input: F, *args: Any, **kwargs: Any) -> F:
            feature_type = type(input)
            try:
                kernel = kernels[feature_type]
            except KeyError:
                try:
                    feature_type, kernel = next(
                        (feature_type, kernel)
                        for feature_type, kernel in kernels.items()
                        if isinstance(input, feature_type)
                    )
                except StopIteration:
                    raise TypeError(f"No support for {type(input).__name__}") from None

            if kernel is None:
                output = dispatch_fn(input, *args, **kwargs)
                if output is None:
                    raise RuntimeError(
                        f"{dispatch_fn.__name__}() did not handle inputs of type {type(input).__name__} "
                        f"although it was configured to do so."
                    )
            else:
                output = kernel(input, *args, **kwargs)

            if issubclass(feature_type, features._Feature) and type(output) is torch.Tensor:
                output = feature_type.new_like(input, output)

            return output

        return inner_wrapper

    return outer_wrapper
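A small sketch of how this decorator resolves kernels, using a toy kernel and a hypothetical double dispatcher (neither is part of torchvision): exact type matches win, plain tensors stay plain, and feature inputs whose kernel returns a plain tensor are re-wrapped via new_like:

import torch
from torchvision.prototype import features

def _double(input: torch.Tensor) -> torch.Tensor:
    # Toy stand-in kernel; works for plain tensors and tensor subclasses alike.
    return input * 2

@dispatch(
    {
        torch.Tensor: _double,
        features.Image: _double,
    }
)
def double(input, *args, **kwargs):
    ...

plain = double(torch.ones(3, 4, 4))                   # stays a plain torch.Tensor
image = double(features.Image(torch.ones(3, 4, 4)))   # comes back as an Image feature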
from torchvision.transforms import InterpolationMode # usort: skip
from ._meta_conversion import convert_bounding_box_format, convert_color_space # usort: skip
from ._augment import (
    erase_image,
    mixup_image,
    mixup_one_hot_label,
    cutmix_image,
    cutmix_one_hot_label,
)
from ._color import (
    adjust_brightness_image,
    adjust_contrast_image,
    adjust_saturation_image,
    adjust_sharpness_image,
    posterize_image,
    solarize_image,
    autocontrast_image,
    equalize_image,
    invert_image,
)
from ._geometry import (
    horizontal_flip_bounding_box,
    horizontal_flip_image,
    resize_bounding_box,
    resize_image,
    resize_segmentation_mask,
    center_crop_image,
    resized_crop_image,
    affine_image,
    rotate_image,
)
from ._misc import normalize_image
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
from typing import Tuple

import torch
from torchvision.transforms import functional as _F

erase_image = _F.erase


def _mixup(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor:
    input = input.clone()
    return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam))


def mixup_image(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor:
    if image_batch.ndim < 4:
        raise ValueError("Need a batch of images")

    return _mixup(image_batch, -4, lam)


def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor:
    if one_hot_label_batch.ndim < 2:
        raise ValueError("Need a batch of one hot labels")

    return _mixup(one_hot_label_batch, -2, lam)


def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor:
    if image_batch.ndim < 4:
        raise ValueError("Need a batch of images")

    x1, y1, x2, y2 = box
    image_rolled = image_batch.roll(1, -4)

    image_batch = image_batch.clone()
    image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
    return image_batch


def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor:
    if one_hot_label_batch.ndim < 2:
        raise ValueError("Need a batch of one hot labels")

    return _mixup(one_hot_label_batch, -2, lam_adjusted)
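A quick numeric check of the mixup kernels above: with lam = 0.7, each element is blended with its neighbor rolled along the batch dimension:

import torch

batch = torch.stack([torch.zeros(3, 4, 4), torch.ones(3, 4, 4)])
mixed = mixup_image(batch, lam=0.7)
# mixed[0] == 0.7 * batch[0] + 0.3 * batch[1]  -> all entries 0.3
# mixed[1] == 0.7 * batch[1] + 0.3 * batch[0]  -> all entries 0.7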
from torchvision.transforms import functional as _F
adjust_brightness_image = _F.adjust_brightness
adjust_saturation_image = _F.adjust_saturation
adjust_contrast_image = _F.adjust_contrast
adjust_sharpness_image = _F.adjust_sharpness
posterize_image = _F.posterize
solarize_image = _F.solarize
autocontrast_image = _F.autocontrast
equalize_image = _F.equalize
invert_image = _F.invert
from typing import Tuple, List, Optional, TypeVar

import torch
from torchvision.prototype import features
from torchvision.transforms import functional as _F, InterpolationMode

from ._meta_conversion import convert_bounding_box_format

T = TypeVar("T", bound=features._Feature)

horizontal_flip_image = _F.hflip


def horizontal_flip_bounding_box(
    bounding_box: torch.Tensor, *, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_box.shape

    bounding_box = convert_bounding_box_format(
        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
    ).view(-1, 4)

    bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]]

    return convert_bounding_box_format(
        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
    ).view(shape)


def resize_image(
    image: torch.Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = None,
) -> torch.Tensor:
    new_height, new_width = size
    num_channels, old_height, old_width = image.shape[-3:]
    batch_shape = image.shape[:-3]
    return _F.resize(
        image.reshape((-1, num_channels, old_height, old_width)),
        size=size,
        interpolation=interpolation,
        max_size=max_size,
        antialias=antialias,
    ).reshape(batch_shape + (num_channels, new_height, new_width))


def resize_segmentation_mask(
    segmentation_mask: torch.Tensor,
    size: List[int],
    max_size: Optional[int] = None,
) -> torch.Tensor:
    return resize_image(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)


# TODO: handle max_size
def resize_bounding_box(bounding_box: torch.Tensor, *, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor:
    old_height, old_width = image_size
    new_height, new_width = size
    ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
    return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)


center_crop_image = _F.center_crop

resized_crop_image = _F.resized_crop

affine_image = _F.affine

rotate_image = _F.rotate
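A small numeric sketch of resize_bounding_box: a single box in (x1, y1, x2, y2) layout on a 100x100 image, resized to (height, width) = (200, 50), has its x coordinates scaled by 0.5 and its y coordinates by 2.0:

import torch

box = torch.tensor([10.0, 20.0, 30.0, 40.0])
resized = resize_bounding_box(box, size=[200, 50], image_size=(100, 100))
# ratios are (new_width / old_width, new_height / old_height) = (0.5, 2.0)
# -> tensor([ 5., 40., 15., 80.])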
from torchvision.transforms import functional as _F
normalize_image = _F.normalize
import unittest.mock
from typing import Dict, Any, Tuple, cast

import numpy as np
import PIL.Image
import torch
from torch.nn.functional import one_hot

from torchvision.io.video import read_video
from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer


def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor:
    image = torch.as_tensor(np.array(PIL.Image.open(ReadOnlyTensorBuffer(encoded_image)), copy=True))
    if image.ndim == 2:
        image = image.unsqueeze(2)
    return image.permute(2, 0, 1)


def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    with unittest.mock.patch("torchvision.io.video.os.path.exists", return_value=True):
        return read_video(ReadOnlyTensorBuffer(encoded_video))  # type: ignore[arg-type]


def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
    return cast(torch.Tensor, one_hot(label, num_classes=num_categories))
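For completeness, a tiny sketch of label_to_one_hot on integer labels:

import torch

labels = torch.tensor([0, 2, 1])
one_hot_labels = label_to_one_hot(labels, num_categories=3)
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]])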