Unverified Commit 7251769f authored by Philip Meier, committed by GitHub

Transforms without dispatcher (#5421)



* add prototype transforms that don't need dispatchers

* cleanup

* remove legacy_transform decorator

* remove legacy classes

* remove explicit param passing

* streamline extra_repr

* remove obsolete ._supports() method

* cleanup

* remove Query

* cleanup

* fix tests

* kernels -> functional

* move image size and num channels extraction to functional

* extend legacy function to extract image size and num channels

* implement dispatching for auto augment

* fix auto augment dispatch

* revert some naming changes

* remove ability to pass params to autoaugment

* fix legacy image size extraction

* align prototype.transforms.functional with transforms.functional

* cleanup

* fix image size and channels extraction

* fix affine and rotate

* revert image size to (width, height)

* Minor corrections
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent f15ba56f
import inspect
from typing import Any, Optional, Callable, TypeVar, Mapping, Type
from typing import Tuple, Union, cast
import PIL.Image
import torch
import torch.overrides
from torchvision.prototype import features
F = TypeVar("F", bound=features._Feature)
class Dispatcher:
"""Wrap a function to automatically dispatch to registered kernels based on the call arguments.
The wrapped function should have this signature
.. code:: python
@dispatch(
...
)
def dispatch_fn(input, *args, **kwargs):
...
where ``input`` is used to determine which kernel to dispatch to.
Args:
kernels: Dictionary with types as keys that maps to a kernel to call. The resolution order is checking for
exact type matches first and if none is found falls back to checking for subclasses. If a value is
``None``, the decorated function is called.
Raises:
TypeError: If any non-``None`` value in ``kernels`` is not callable with ``kernel(input, *args, **kwargs)``.
TypeError: If the decorated function is called with an input that cannot be dispatched.
"""
def __init__(self, fn: Callable, kernels: Mapping[Type, Optional[Callable]]):
self._fn = fn
for feature_type, kernel in kernels.items():
if not self._check_kernel(kernel):
raise TypeError(
f"Kernel for feature type {feature_type.__name__} is not callable with "
f"kernel(input, *args, **kwargs)."
)
self._kernels = kernels
def _check_kernel(self, kernel: Optional[Callable]) -> bool:
if kernel is None:
return True
if not callable(kernel):
return False
params = list(inspect.signature(kernel).parameters.values())
if not params:
return False
return params[0].kind != inspect.Parameter.KEYWORD_ONLY
def _resolve(self, feature_type: Type) -> Optional[Callable]:
try:
return self._kernels[feature_type]
except KeyError:
try:
return next(
kernel
for registered_feature_type, kernel in self._kernels.items()
if issubclass(feature_type, registered_feature_type)
)
except StopIteration:
raise TypeError(f"No support for feature type {feature_type.__name__}") from None
def __contains__(self, obj: Any) -> bool:
try:
self._resolve(type(obj))
return True
except TypeError:
return False
def __call__(self, input: Any, *args: Any, **kwargs: Any) -> Any:
kernel = self._resolve(type(input))
if kernel is None:
output = self._fn(input, *args, **kwargs)
if output is None:
raise RuntimeError(
f"{self._fn.__name__}() did not handle inputs of type {type(input).__name__} "
f"although it was configured to do so."
)
else:
output = kernel(input, *args, **kwargs)
if isinstance(input, features._Feature) and type(output) is torch.Tensor:
output = type(input).new_like(input, output)
return output
def dispatch(kernels: Mapping[Type, Optional[Callable]]) -> Callable[[Callable], Dispatcher]:
"""Decorates a function and turns it into a :class:`Dispatcher`."""
return lambda fn: Dispatcher(fn, kernels)
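# Hedged usage sketch (not part of the commit): ``_tensor_passthrough`` and
# ``example_op`` are hypothetical names, added only to illustrate how a
# ``@dispatch`` table routes calls. A plain tensor goes to its registered
# kernel, while ``features.Image`` (mapped to ``None``) falls through to the
# decorated function itself.
def _tensor_passthrough(input: torch.Tensor) -> torch.Tensor:
    return input
@dispatch(
    {
        torch.Tensor: _tensor_passthrough,
        features.Image: None,
    }
)
def example_op(input: Any, *args: Any, **kwargs: Any) -> Any:
    # Only reached for ``features.Image`` inputs; must return a non-None value.
    return input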
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]:
    if isinstance(image, features.Image):
        height, width = image.image_size
        return width, height
    elif isinstance(image, torch.Tensor):
        return cast(Tuple[int, int], tuple(_FT.get_image_size(image)))
    elif isinstance(image, PIL.Image.Image):
        return cast(Tuple[int, int], tuple(_FP.get_image_size(image)))
    else:
        raise TypeError(f"unable to get image size from object of type {type(image).__name__}")
def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int:
    if isinstance(image, features.Image):
        return image.num_channels
    elif isinstance(image, torch.Tensor):
        return _FT.get_image_num_channels(image)
    elif isinstance(image, PIL.Image.Image):
        return cast(int, _FP.get_image_num_channels(image))
    else:
        raise TypeError(f"unable to get num channels from object of type {type(image).__name__}")
from torchvision.transforms import InterpolationMode # usort: skip
from ._meta_conversion import convert_bounding_box_format, convert_color_space # usort: skip
from ._augment import (
erase_image,
mixup_image,
mixup_one_hot_label,
cutmix_image,
cutmix_one_hot_label,
)
from ._color import (
adjust_brightness_image,
adjust_contrast_image,
adjust_saturation_image,
adjust_sharpness_image,
posterize_image,
solarize_image,
autocontrast_image,
equalize_image,
invert_image,
adjust_hue_image,
adjust_gamma_image,
)
from ._geometry import (
horizontal_flip_bounding_box,
horizontal_flip_image,
resize_bounding_box,
resize_image,
resize_segmentation_mask,
center_crop_image,
resized_crop_image,
affine_image,
rotate_image,
pad_image,
crop_image,
perspective_image,
vertical_flip_image,
five_crop_image,
ten_crop_image,
)
from ._misc import normalize_image, gaussian_blur_image
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
from typing import Tuple
import torch
from torchvision.transforms import functional as _F
erase_image = _F.erase
def _mixup(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor:
    # Blend every sample with its neighbor along the batch dimension. The rolled
    # copy is materialized before the in-place ops run, so the result equals
    # input.roll(1, batch_dim) * (1 - lam) + input * lam.
    input = input.clone()
    return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam))
def mixup_image(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor:
if image_batch.ndim < 4:
raise ValueError("Need a batch of images")
return _mixup(image_batch, -4, lam)
def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor:
if one_hot_label_batch.ndim < 2:
raise ValueError("Need a batch of one hot labels")
return _mixup(one_hot_label_batch, -2, lam)
def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor:
if image_batch.ndim < 4:
raise ValueError("Need a batch of images")
x1, y1, x2, y2 = box
image_rolled = image_batch.roll(1, -4)
image_batch = image_batch.clone()
image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
return image_batch
def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor:
if one_hot_label_batch.ndim < 2:
raise ValueError("Need a batch of one hot labels")
return _mixup(one_hot_label_batch, -2, lam_adjusted)
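# Hedged usage sketch (not part of the commit): shapes, lam, and the box are
# illustrative values only. Image and label batches are blended with the same
# lam so that each mixed image stays paired with its mixed label.
_image_batch = torch.rand(8, 3, 224, 224)
_one_hot_batch = torch.eye(10)[torch.randint(10, (8,))]  # shape (8, 10)
_lam = 0.3
_mixed_images = mixup_image(_image_batch, lam=_lam)
_mixed_labels = mixup_one_hot_label(_one_hot_batch, lam=_lam)
_cutmix_images = cutmix_image(_image_batch, box=(50, 60, 150, 160))  # x1, y1, x2, y2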
from torchvision.transforms import functional as _F
adjust_brightness_image = _F.adjust_brightness
adjust_saturation_image = _F.adjust_saturation
adjust_contrast_image = _F.adjust_contrast
adjust_sharpness_image = _F.adjust_sharpness
posterize_image = _F.posterize
solarize_image = _F.solarize
autocontrast_image = _F.autocontrast
equalize_image = _F.equalize
invert_image = _F.invert
adjust_hue_image = _F.adjust_hue
adjust_gamma_image = _F.adjust_gamma
from typing import Tuple, List, Optional, TypeVar
import torch
from torchvision.prototype import features
from torchvision.transforms import functional as _F, InterpolationMode
from ._meta_conversion import convert_bounding_box_format
T = TypeVar("T", bound=features._Feature)
horizontal_flip_image = _F.hflip
def horizontal_flip_bounding_box(
bounding_box: torch.Tensor, *, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
shape = bounding_box.shape
bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]]
return convert_bounding_box_format(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
).view(shape)
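# Hedged sketch (not part of the commit): the values are illustrative. On an
# image of size (height=100, width=200), flipping the XYXY box (10, 20, 30, 40)
# horizontally mirrors its x-coordinates around the image width:
# new x1 = 200 - 30 = 170 and new x2 = 200 - 10 = 190.
_box_xyxy = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
_flipped_box = horizontal_flip_bounding_box(
    _box_xyxy, format=features.BoundingBoxFormat.XYXY, image_size=(100, 200)
)
# _flipped_box is tensor([[170., 20., 190., 40.]])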
def resize_image(
image: torch.Tensor,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> torch.Tensor:
new_height, new_width = size
num_channels, old_height, old_width = image.shape[-3:]
batch_shape = image.shape[:-3]
    # Collapse any leading batch dimensions into a single one before resizing,
    # then restore them afterwards.
    return _F.resize(
        image.reshape((-1, num_channels, old_height, old_width)),
        size=size,
        interpolation=interpolation,
        max_size=max_size,
        antialias=antialias,
    ).reshape(batch_shape + (num_channels, new_height, new_width))
def resize_segmentation_mask(
segmentation_mask: torch.Tensor,
size: List[int],
max_size: Optional[int] = None,
) -> torch.Tensor:
return resize_image(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
# TODO: handle max_size
def resize_bounding_box(bounding_box: torch.Tensor, *, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor:
old_height, old_width = image_size
new_height, new_width = size
ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)
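# Hedged sketch (not part of the commit): the box and sizes are illustrative.
# Doubling the image size from (100, 100) to (200, 200) scales every coordinate
# by the corresponding width/height ratio.
_small_box = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # XYXY on a 100x100 image
_resized_box = resize_bounding_box(_small_box, size=[200, 200], image_size=(100, 100))
# _resized_box is tensor([[20., 40., 60., 80.]])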
center_crop_image = _F.center_crop
resized_crop_image = _F.resized_crop
affine_image = _F.affine
rotate_image = _F.rotate
pad_image = _F.pad
crop_image = _F.crop
perspective_image = _F.perspective
vertical_flip_image = _F.vflip
five_crop_image = _F.five_crop
ten_crop_image = _F.ten_crop
import torch
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.transforms.functional_tensor import rgb_to_grayscale as _rgb_to_grayscale
def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
xyxy = xywh.clone()
xyxy[..., 2:] += xyxy[..., :2]
return xyxy
def _xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
xywh = xyxy.clone()
xywh[..., 2:] -= xywh[..., :2]
return xywh
def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor:
cx, cy, w, h = torch.unbind(cxcywh, dim=-1)
x1 = cx - 0.5 * w
y1 = cy - 0.5 * h
x2 = cx + 0.5 * w
y2 = cy + 0.5 * h
return torch.stack((x1, y1, x2, y2), dim=-1)
def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
x1, y1, x2, y2 = torch.unbind(xyxy, dim=-1)
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
return torch.stack((cx, cy, w, h), dim=-1)
def convert_bounding_box_format(
bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
) -> torch.Tensor:
if new_format == old_format:
return bounding_box.clone()
if old_format == BoundingBoxFormat.XYWH:
bounding_box = _xywh_to_xyxy(bounding_box)
elif old_format == BoundingBoxFormat.CXCYWH:
bounding_box = _cxcywh_to_xyxy(bounding_box)
if new_format == BoundingBoxFormat.XYWH:
bounding_box = _xyxy_to_xywh(bounding_box)
elif new_format == BoundingBoxFormat.CXCYWH:
bounding_box = _xyxy_to_cxcywh(bounding_box)
return bounding_box
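# Hedged sketch (not part of the commit): the coordinates are illustrative.
# The XYWH box (x=10, y=20, w=30, h=40) corresponds to the XYXY box
# (10, 20, 40, 60) and therefore to the CXCYWH box (25, 40, 30, 40).
_xywh_box = torch.tensor([10.0, 20.0, 30.0, 40.0])
_cxcywh_box = convert_bounding_box_format(
    _xywh_box, old_format=BoundingBoxFormat.XYWH, new_format=BoundingBoxFormat.CXCYWH
)
# _cxcywh_box is tensor([25., 40., 30., 40.])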
def _grayscale_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
    # Expand the single channel dimension to three while keeping height and width.
    return grayscale.expand(3, -1, -1)
def convert_color_space(image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace) -> torch.Tensor:
if new_color_space == old_color_space:
return image.clone()
if old_color_space == ColorSpace.GRAYSCALE:
image = _grayscale_to_rgb(image)
if new_color_space == ColorSpace.GRAYSCALE:
image = _rgb_to_grayscale(image)
return image
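# Hedged sketch (not part of the commit): the tensor is an illustrative value
# and this assumes ColorSpace exposes an RGB member alongside GRAYSCALE.
_gray_image = torch.rand(1, 16, 16)
_rgb_image = convert_color_space(_gray_image, ColorSpace.GRAYSCALE, ColorSpace.RGB)
# _rgb_image has shape (3, 16, 16)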
from torchvision.transforms import functional as _F
normalize_image = _F.normalize
gaussian_blur_image = _F.gaussian_blur