"tests/vscode:/vscode.git/clone" did not exist on "ae4a5b739412d817da36b86c858f00e9605022a9"
Unverified commit 7251769f authored by Philip Meier, committed by GitHub

Transforms without dispatcher (#5421)



* add prototype transforms that don't need dispatchers

* cleanup

* remove legacy_transform decorator

* remove legacy classes

* remove explicit param passing

* streamline extra_repr

* remove obsolete ._supports() method

* cleanup

* remove Query

* cleanup

* fix tests

* kernels -> functional

* move image size and num channels extraction to functional

* extend legacy function to extract image size and num channels

* implement dispatching for auto augment

* fix auto augment dispatch

* revert some naming changes

* remove ability to pass params to autoaugment

* fix legacy image size extraction

* align prototype.transforms.functional with transforms.functional

* cleanup

* fix image size and channels extraction

* fix affine and rotate

* revert image size to (width, height)

* Minor corrections
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent f15ba56f
# Added in this commit: dispatcher-free helpers for extracting image metadata.
from typing import Tuple, Union, cast

import PIL.Image
import torch

from torchvision.prototype import features
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP


def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]:
    if isinstance(image, features.Image):
        height, width = image.image_size
        return width, height
    elif isinstance(image, torch.Tensor):
        return cast(Tuple[int, int], tuple(_FT.get_image_size(image)))
    if isinstance(image, PIL.Image.Image):
        return cast(Tuple[int, int], tuple(_FP.get_image_size(image)))
    else:
        raise TypeError(f"unable to get image size from object of type {type(image).__name__}")


def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int:
    if isinstance(image, features.Image):
        return image.num_channels
    elif isinstance(image, torch.Tensor):
        return _FT.get_image_num_channels(image)
    if isinstance(image, PIL.Image.Image):
        return cast(int, _FP.get_image_num_channels(image))
    else:
        raise TypeError(f"unable to get num channels from object of type {type(image).__name__}")


# Removed in this commit: the Dispatcher machinery that the kernels were previously registered with.
import inspect
from typing import Any, Optional, Callable, TypeVar, Mapping, Type

import torch
import torch.overrides

from torchvision.prototype import features

F = TypeVar("F", bound=features._Feature)


class Dispatcher:
    """Wrap a function to automatically dispatch to registered kernels based on the call arguments.

    The wrapped function should have this signature

    .. code:: python

        @dispatch(
            ...
        )
        def dispatch_fn(input, *args, **kwargs):
            ...

    where ``input`` is used to determine which kernel to dispatch to.

    Args:
        kernels: Dictionary with types as keys that maps to a kernel to call. The resolution order is checking for
            exact type matches first and if none is found falls back to checking for subclasses. If a value is
            ``None``, the decorated function is called.

    Raises:
        TypeError: If any value in ``kernels`` is not callable with ``kernel(input, *args, **kwargs)``.
        TypeError: If the decorated function is called with an input that cannot be dispatched.
    """

    def __init__(self, fn: Callable, kernels: Mapping[Type, Optional[Callable]]):
        self._fn = fn

        for feature_type, kernel in kernels.items():
            if not self._check_kernel(kernel):
                raise TypeError(
                    f"Kernel for feature type {feature_type.__name__} is not callable with "
                    f"kernel(input, *args, **kwargs)."
                )

        self._kernels = kernels

    def _check_kernel(self, kernel: Optional[Callable]) -> bool:
        if kernel is None:
            return True

        if not callable(kernel):
            return False

        params = list(inspect.signature(kernel).parameters.values())
        if not params:
            return False

        return params[0].kind != inspect.Parameter.KEYWORD_ONLY

    def _resolve(self, feature_type: Type) -> Optional[Callable]:
        try:
            return self._kernels[feature_type]
        except KeyError:
            try:
                return next(
                    kernel
                    for registered_feature_type, kernel in self._kernels.items()
                    if issubclass(feature_type, registered_feature_type)
                )
            except StopIteration:
                raise TypeError(f"No support for feature type {feature_type.__name__}") from None

    def __contains__(self, obj: Any) -> bool:
        try:
            self._resolve(type(obj))
            return True
        except TypeError:
            return False

    def __call__(self, input: Any, *args: Any, **kwargs: Any) -> Any:
        kernel = self._resolve(type(input))

        if kernel is None:
            output = self._fn(input, *args, **kwargs)
            if output is None:
                raise RuntimeError(
                    f"{self._fn.__name__}() did not handle inputs of type {type(input).__name__} "
                    f"although it was configured to do so."
                )
        else:
            output = kernel(input, *args, **kwargs)

        if isinstance(input, features._Feature) and type(output) is torch.Tensor:
            output = type(input).new_like(input, output)

        return output


def dispatch(kernels: Mapping[Type, Optional[Callable]]) -> Callable[[Callable], Dispatcher]:
    """Decorates a function and turns it into a :class:`Dispatcher`."""
    return lambda fn: Dispatcher(fn, kernels)

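For context, the docstring above only sketches the intended usage. The following is a hypothetical example, not part of this diff, of how kernels were registered through @dispatch; it assumes the horizontal_flip_image kernel re-exported further down in this diff and the _FP (functional_pil) import are in scope.

# Hypothetical sketch (not part of this commit): registering kernels with @dispatch.
# Assumes `horizontal_flip_image` (defined further down in this diff) and
# `from torchvision.transforms import functional_pil as _FP` are in scope.
@dispatch(
    {
        torch.Tensor: horizontal_flip_image,
        features.Image: horizontal_flip_image,
        PIL.Image.Image: None,  # None means: fall back to the decorated function itself
    }
)
def horizontal_flip(input, *args, **kwargs):
    # Only reached for PIL images; the Dispatcher requires a non-None return value here.
    return _FP.hflip(input)

When a features.Image is passed, any plain-tensor output from the resolved kernel is re-wrapped via new_like; the new get_image_size / get_image_num_channels helpers above sidestep this machinery with plain isinstance checks.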
from torchvision.transforms import InterpolationMode # usort: skip
from ._meta_conversion import convert_bounding_box_format, convert_color_space # usort: skip
from ._augment import (
erase_image,
mixup_image,
mixup_one_hot_label,
cutmix_image,
cutmix_one_hot_label,
)
from ._color import (
adjust_brightness_image,
adjust_contrast_image,
adjust_saturation_image,
adjust_sharpness_image,
posterize_image,
solarize_image,
autocontrast_image,
equalize_image,
invert_image,
adjust_hue_image,
adjust_gamma_image,
)
from ._geometry import (
horizontal_flip_bounding_box,
horizontal_flip_image,
resize_bounding_box,
resize_image,
resize_segmentation_mask,
center_crop_image,
resized_crop_image,
affine_image,
rotate_image,
pad_image,
crop_image,
perspective_image,
vertical_flip_image,
five_crop_image,
ten_crop_image,
)
from ._misc import normalize_image, gaussian_blur_image
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
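A minimal usage sketch, assuming the listing above is the __init__ of torchvision.prototype.transforms.functional as laid out in this PR; the input values are illustrative.

# Minimal sketch (assumption: the imports above form torchvision.prototype.transforms.functional).
import torch
from torchvision.prototype.transforms import functional as F

image = torch.rand(3, 64, 64)
image = F.adjust_brightness_image(image, brightness_factor=1.5)
image = F.horizontal_flip_image(image)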
from typing import Tuple
import torch
from torchvision.transforms import functional as _F
erase_image = _F.erase
def _mixup(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor:
    input = input.clone()
    return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam))


def mixup_image(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor:
    if image_batch.ndim < 4:
        raise ValueError("Need a batch of images")

    return _mixup(image_batch, -4, lam)


def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor:
    if one_hot_label_batch.ndim < 2:
        raise ValueError("Need a batch of one hot labels")

    return _mixup(one_hot_label_batch, -2, lam)


def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor:
    if image_batch.ndim < 4:
        raise ValueError("Need a batch of images")

    x1, y1, x2, y2 = box
    image_rolled = image_batch.roll(1, -4)

    image_batch = image_batch.clone()
    image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
    return image_batch


def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor:
    if one_hot_label_batch.ndim < 2:
        raise ValueError("Need a batch of one hot labels")

    return _mixup(one_hot_label_batch, -2, lam_adjusted)

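A short illustrative sketch of how the batch kernels above compose, assuming a toy batch; the lam value and the CutMix box are arbitrary, and lam_adjusted uses the usual area-based correction.

# Illustrative sketch: mixing a toy batch of four RGB images and their one-hot labels.
import torch

image_batch = torch.rand(4, 3, 32, 32)                  # (N, C, H, W); ndim >= 4 is required
one_hot = torch.eye(10)[torch.tensor([1, 3, 5, 7])]     # (N, num_classes)

lam = 0.7
mixed_images = mixup_image(image_batch, lam=lam)
mixed_labels = mixup_one_hot_label(one_hot, lam=lam)

box = (8, 8, 24, 24)                                     # (x1, y1, x2, y2), illustrative
lam_adjusted = 1.0 - ((24 - 8) * (24 - 8)) / (32 * 32)   # fraction of the image left untouched
cut_images = cutmix_image(image_batch, box=box)
cut_labels = cutmix_one_hot_label(one_hot, lam_adjusted=lam_adjusted)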
from torchvision.transforms import functional as _F
adjust_brightness_image = _F.adjust_brightness
adjust_saturation_image = _F.adjust_saturation
adjust_contrast_image = _F.adjust_contrast
adjust_sharpness_image = _F.adjust_sharpness
posterize_image = _F.posterize
solarize_image = _F.solarize
autocontrast_image = _F.autocontrast
equalize_image = _F.equalize
invert_image = _F.invert
adjust_hue_image = _F.adjust_hue
adjust_gamma_image = _F.adjust_gamma
from typing import Tuple, List, Optional, TypeVar
import torch
from torchvision.prototype import features
from torchvision.transforms import functional as _F, InterpolationMode
from ._meta_conversion import convert_bounding_box_format
T = TypeVar("T", bound=features._Feature)
horizontal_flip_image = _F.hflip
def horizontal_flip_bounding_box(
    bounding_box: torch.Tensor, *, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_box.shape

    bounding_box = convert_bounding_box_format(
        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
    ).view(-1, 4)

    bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]]

    return convert_bounding_box_format(
        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
    ).view(shape)


def resize_image(
    image: torch.Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = None,
) -> torch.Tensor:
    new_height, new_width = size
    num_channels, old_height, old_width = image.shape[-3:]
    batch_shape = image.shape[:-3]
    return _F.resize(
        image.reshape((-1, num_channels, old_height, old_width)),
        size=size,
        interpolation=interpolation,
        max_size=max_size,
        antialias=antialias,
    ).reshape(batch_shape + (num_channels, new_height, new_width))


def resize_segmentation_mask(
    segmentation_mask: torch.Tensor,
    size: List[int],
    max_size: Optional[int] = None,
) -> torch.Tensor:
    return resize_image(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)


# TODO: handle max_size
def resize_bounding_box(bounding_box: torch.Tensor, *, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor:
    old_height, old_width = image_size
    new_height, new_width = size
    ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
    return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)

center_crop_image = _F.center_crop
resized_crop_image = _F.resized_crop
affine_image = _F.affine
rotate_image = _F.rotate
pad_image = _F.pad
crop_image = _F.crop
perspective_image = _F.perspective
vertical_flip_image = _F.vflip
five_crop_image = _F.five_crop
ten_crop_image = _F.ten_crop
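An illustrative sketch of keeping an image and its boxes in sync, assuming a single XYXY box on a 100x200 image; note that image_size is passed as (height, width) here.

# Illustrative sketch: resizing an image and its bounding boxes together.
import torch

image = torch.rand(3, 100, 200)                          # (C, H, W)
boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])         # XYXY, relative to the 100x200 image

new_size = [50, 100]                                      # (new_height, new_width)
resized_image = resize_image(image, size=new_size)
resized_boxes = resize_bounding_box(boxes, size=new_size, image_size=(100, 200))  # coordinates halved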
import torch
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.transforms.functional_tensor import rgb_to_grayscale as _rgb_to_grayscale
def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
    xyxy = xywh.clone()
    xyxy[..., 2:] += xyxy[..., :2]
    return xyxy


def _xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
    xywh = xyxy.clone()
    xywh[..., 2:] -= xywh[..., :2]
    return xywh


def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor:
    cx, cy, w, h = torch.unbind(cxcywh, dim=-1)
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    return torch.stack((x1, y1, x2, y2), dim=-1)


def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
    x1, y1, x2, y2 = torch.unbind(xyxy, dim=-1)
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    return torch.stack((cx, cy, w, h), dim=-1)


def convert_bounding_box_format(
    bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
) -> torch.Tensor:
    if new_format == old_format:
        return bounding_box.clone()

    if old_format == BoundingBoxFormat.XYWH:
        bounding_box = _xywh_to_xyxy(bounding_box)
    elif old_format == BoundingBoxFormat.CXCYWH:
        bounding_box = _cxcywh_to_xyxy(bounding_box)

    if new_format == BoundingBoxFormat.XYWH:
        bounding_box = _xyxy_to_xywh(bounding_box)
    elif new_format == BoundingBoxFormat.CXCYWH:
        bounding_box = _xyxy_to_cxcywh(bounding_box)

    return bounding_box

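A small worked example of the two-step conversion through the XYXY intermediate; the box values are illustrative.

# Illustrative sketch: XYWH -> XYXY -> CXCYWH for a single box.
import torch

xywh = torch.tensor([10.0, 20.0, 30.0, 40.0])            # x, y, width, height
xyxy = convert_bounding_box_format(
    xywh, old_format=BoundingBoxFormat.XYWH, new_format=BoundingBoxFormat.XYXY
)                                                         # tensor([10., 20., 40., 60.])
cxcywh = convert_bounding_box_format(
    xyxy, old_format=BoundingBoxFormat.XYXY, new_format=BoundingBoxFormat.CXCYWH
)                                                         # tensor([25., 40., 30., 40.])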
def _grayscale_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
    # -1 keeps the existing spatial sizes; only the channel dimension is broadcast to 3
    return grayscale.expand(3, -1, -1)


def convert_color_space(image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace) -> torch.Tensor:
    if new_color_space == old_color_space:
        return image.clone()

    if old_color_space == ColorSpace.GRAYSCALE:
        image = _grayscale_to_rgb(image)

    if new_color_space == ColorSpace.GRAYSCALE:
        image = _rgb_to_grayscale(image)

    return image

from torchvision.transforms import functional as _F
normalize_image = _F.normalize
gaussian_blur_image = _F.gaussian_blur
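A minimal sketch of the re-exported misc kernels, which keep the stock torchvision.transforms.functional signatures; the values are illustrative.

# Illustrative sketch: normalize and blur a float image tensor.
import torch

image = torch.rand(3, 32, 32)
image = normalize_image(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
image = gaussian_blur_image(image, kernel_size=[3, 3])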