Unverified commit 95d41897 authored by Philip Meier, committed by GitHub

add prototype AugMix transform (#5492)

* add prototype AugMix transform

* cleanup

* refactor auto augment subclasses to only transform a single image

* address review comments
parent 7767f120
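
Editor's note: for reviewers who want to try the new transform, a minimal usage sketch (the input shape is an invented example, and the prototype API may still change):

import torch
from torchvision.prototype import transforms

# AugMix defaults: severity=3, mixture_width=3, chain_depth=-1, alpha=1.0, all_ops=True
augmix = transforms.AugMix()
image = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
mixed = augmix(image)  # a convex blend of the original image and augmented chains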
@@ -114,6 +114,7 @@ class TestSmoke:
transforms.RandAugment(),
transforms.TrivialAugmentWide(),
transforms.AutoAugment(),
transforms.AugMix(),
)
]
)
@@ -5,7 +5,7 @@ from . import functional # usort: skip
from ._transform import Transform # usort: skip
from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
import math
from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar
from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
from torchvision.prototype.utils._internal import apply_recursively
from torchvision.prototype.utils._internal import query_recursively
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from ._utils import query_image, get_image_dimensions
from ._utils import get_image_dimensions
K = TypeVar("K")
V = TypeVar("V")
def _put_into_sample(sample: Any, id: Tuple[Any, ...], item: Any) -> Any:
if not id:
return item
parent = sample
for key in id[:-1]:
parent = parent[key]
parent[id[-1]] = item
return sample
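
Editor's note: a tiny illustration of _put_into_sample with invented values. The id is the key path recorded while searching the sample; an empty id means the sample itself is the image, so the item is returned directly:

sample = {"image": "original", "label": 7}
assert _put_into_sample(sample, ("image",), "augmented") == {"image": "augmented", "label": 7}
assert _put_into_sample("original", (), "augmented") == "augmented"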
class _AutoAugmentBase(Transform):
def __init__(
self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None
@@ -26,50 +39,77 @@ class _AutoAugmentBase(Transform):
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _apply_transform(self, sample: Any, transform_id: str, magnitude: float) -> Any:
def dispatch(
def _check_unsupported(self, input: Any) -> None:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
def _extract_image(
self, sample: Any
) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]:
def fn(
id: Tuple[Any, ...], input: Any
) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
return id, input
self._check_unsupported(input)
return None
images = list(query_recursively(fn, sample))
if not images:
raise TypeError("Found no image in the sample.")
if len(images) > 1:
raise TypeError(
f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
)
return images[0]
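
Editor's note: a hedged sketch of the single-image contract enforced above, using an invented sample (any _AutoAugmentBase subclass works; this calls a private helper purely for illustration):

import torch
from torchvision.prototype import transforms

transform = transforms.AugMix()
sample = {"img": torch.rand(3, 4, 4), "label": 3}
id, image = transform._extract_image(sample)
assert id == ("img",) and image is sample["img"]

# A sample containing two images (or none) raises a TypeError:
# transform._extract_image([torch.rand(3, 4, 4), torch.rand(3, 4, 4)])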
def _parse_fill(
self, image: Union[PIL.Image.Image, torch.Tensor, features.Image], num_channels: int
) -> Optional[List[float]]:
fill = self.fill
if isinstance(image, PIL.Image.Image) or fill is None:
return fill
if isinstance(fill, (int, float)):
fill = [float(fill)] * num_channels
else:
fill = [float(f) for f in fill]
return fill
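
Editor's note: the fill normalization above in standalone form (a sketch, not the method itself): tensor images expand a scalar fill to one float per channel, while PIL images and fill=None pass the user value through untouched:

def parse_fill_sketch(fill, num_channels, is_pil_image):
    # mirrors _AutoAugmentBase._parse_fill, for illustration only
    if is_pil_image or fill is None:
        return fill
    if isinstance(fill, (int, float)):
        return [float(fill)] * num_channels
    return [float(f) for f in fill]

assert parse_fill_sketch(2, num_channels=3, is_pil_image=False) == [2.0, 2.0, 2.0]
assert parse_fill_sketch(2, num_channels=3, is_pil_image=True) == 2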
def _dispatch_image_kernels(
self,
image_tensor_kernel: Callable,
image_pil_kernel: Callable,
input: Any,
*args: Any,
**kwargs: Any,
) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
if isinstance(input, features.Image):
output = image_tensor_kernel(input, *args, **kwargs)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
return image_tensor_kernel(input, *args, **kwargs)
elif isinstance(input, PIL.Image.Image):
else: # isinstance(input, PIL.Image.Image):
return image_pil_kernel(input, *args, **kwargs)
else:
return input
image = query_image(sample)
num_channels, *_ = get_image_dimensions(image)
fill = self.fill
if isinstance(fill, (int, float)):
fill = [float(fill)] * num_channels
elif fill is not None:
fill = [float(f) for f in fill]
interpolation = self.interpolation
def transform(input: Any) -> Any:
if type(input) in {features.BoundingBox, features.SegmentationMask}:
raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
elif not (type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image)):
return input
def _apply_image_transform(
self,
image: Any,
transform_id: str,
magnitude: float,
interpolation: InterpolationMode,
fill: Optional[List[float]],
) -> Any:
if transform_id == "Identity":
return input
return image
elif transform_id == "ShearX":
return dispatch(
return self._dispatch_image_kernels(
F.affine_image_tensor,
F.affine_image_pil,
input,
image,
angle=0.0,
translate=[0, 0],
scale=1.0,
@@ -78,10 +118,10 @@ class _AutoAugmentBase(Transform):
fill=fill,
)
elif transform_id == "ShearY":
return dispatch(
return self._dispatch_image_kernels(
F.affine_image_tensor,
F.affine_image_pil,
input,
image,
angle=0.0,
translate=[0, 0],
scale=1.0,
@@ -90,10 +130,10 @@ class _AutoAugmentBase(Transform):
fill=fill,
)
elif transform_id == "TranslateX":
return dispatch(
return self._dispatch_image_kernels(
F.affine_image_tensor,
F.affine_image_pil,
input,
image,
angle=0.0,
translate=[int(magnitude), 0],
scale=1.0,
@@ -102,10 +142,10 @@ class _AutoAugmentBase(Transform):
fill=fill,
)
elif transform_id == "TranslateY":
return dispatch(
return self._dispatch_image_kernels(
F.affine_image_tensor,
F.affine_image_pil,
input,
image,
angle=0.0,
translate=[0, int(magnitude)],
scale=1.0,
@@ -114,47 +154,49 @@ class _AutoAugmentBase(Transform):
fill=fill,
)
elif transform_id == "Rotate":
return dispatch(F.rotate_image_tensor, F.rotate_image_pil, input, angle=magnitude)
return self._dispatch_image_kernels(F.rotate_image_tensor, F.rotate_image_pil, image, angle=magnitude)
elif transform_id == "Brightness":
return dispatch(
return self._dispatch_image_kernels(
F.adjust_brightness_image_tensor,
F.adjust_brightness_image_pil,
input,
image,
brightness_factor=1.0 + magnitude,
)
elif transform_id == "Color":
return dispatch(
return self._dispatch_image_kernels(
F.adjust_saturation_image_tensor,
F.adjust_saturation_image_pil,
input,
image,
saturation_factor=1.0 + magnitude,
)
elif transform_id == "Contrast":
return dispatch(
F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, input, contrast_factor=1.0 + magnitude
return self._dispatch_image_kernels(
F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, image, contrast_factor=1.0 + magnitude
)
elif transform_id == "Sharpness":
return dispatch(
return self._dispatch_image_kernels(
F.adjust_sharpness_image_tensor,
F.adjust_sharpness_image_pil,
input,
image,
sharpness_factor=1.0 + magnitude,
)
elif transform_id == "Posterize":
return dispatch(F.posterize_image_tensor, F.posterize_image_pil, input, bits=int(magnitude))
return self._dispatch_image_kernels(
F.posterize_image_tensor, F.posterize_image_pil, image, bits=int(magnitude)
)
elif transform_id == "Solarize":
return dispatch(F.solarize_image_tensor, F.solarize_image_pil, input, threshold=magnitude)
return self._dispatch_image_kernels(
F.solarize_image_tensor, F.solarize_image_pil, image, threshold=magnitude
)
elif transform_id == "AutoContrast":
return dispatch(F.autocontrast_image_tensor, F.autocontrast_image_pil, input)
return self._dispatch_image_kernels(F.autocontrast_image_tensor, F.autocontrast_image_pil, image)
elif transform_id == "Equalize":
return dispatch(F.equalize_image_tensor, F.equalize_image_pil, input)
return self._dispatch_image_kernels(F.equalize_image_tensor, F.equalize_image_pil, image)
elif transform_id == "Invert":
return dispatch(F.invert_image_tensor, F.invert_image_pil, input)
return self._dispatch_image_kernels(F.invert_image_tensor, F.invert_image_pil, image)
else:
raise ValueError(f"No transform available for {transform_id}")
return apply_recursively(transform, sample)
class AutoAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
@@ -277,8 +319,9 @@ class AutoAugment(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
image = query_image(sample)
_, height, width = get_image_dimensions(image)
id, image = self._extract_image(sample)
num_channels, height, width = get_image_dimensions(image)
fill = self._parse_fill(image, num_channels)
policy = self._policies[int(torch.randint(len(self._policies), ()))]
@@ -296,9 +339,11 @@ class AutoAugment(_AutoAugmentBase):
else:
magnitude = 0.0
sample = self._apply_transform(sample, transform_id, magnitude)
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=fill
)
return sample
return _put_into_sample(sample, id, image)
class RandAugment(_AutoAugmentBase):
@@ -333,8 +378,9 @@ class RandAugment(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
image = query_image(sample)
_, height, width = get_image_dimensions(image)
id, image = self._extract_image(sample)
num_channels, height, width = get_image_dimensions(image)
fill = self._parse_fill(image, num_channels)
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -347,9 +393,11 @@ class RandAugment(_AutoAugmentBase):
else:
magnitude = 0.0
sample = self._apply_transform(sample, transform_id, magnitude)
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=fill
)
return sample
return _put_into_sample(sample, id, image)
class TrivialAugmentWide(_AutoAugmentBase):
@@ -382,8 +430,9 @@ class TrivialAugmentWide(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
image = query_image(sample)
_, height, width = get_image_dimensions(image)
id, image = self._extract_image(sample)
num_channels, height, width = get_image_dimensions(image)
fill = self._parse_fill(image, num_channels)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -395,4 +444,110 @@ class TrivialAugmentWide(_AutoAugmentBase):
else:
magnitude = 0.0
return self._apply_transform(sample, transform_id, magnitude)
image = self._apply_image_transform(image, transform_id, magnitude, interpolation=self.interpolation, fill=fill)
return _put_into_sample(sample, id, image)
class AugMix(_AutoAugmentBase):
_PARTIAL_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Posterize": (
lambda num_bins, image_size: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, image_size: None, False),
"Equalize": (lambda num_bins, image_size: None, False),
}
_AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, Tuple[int, int]], Optional[torch.Tensor]], bool]] = {
**_PARTIAL_AUGMENTATION_SPACE,
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
}
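
Editor's note: for intuition about the bins above, the Posterize entry evaluates to the following ten magnitudes with _PARAMETER_MAX = 10; since forward() samples an index below self.severity, low severities stay close to 4 bits:

import torch

num_bins = 10
bits = (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int()
print(bits)  # tensor([4, 4, 3, 3, 2, 2, 1, 1, 0, 0], dtype=torch.int32)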
def __init__(
self,
severity: int = 3,
mixture_width: int = 3,
chain_depth: int = -1,
alpha: float = 1.0,
all_ops: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._PARAMETER_MAX = 10
if not (1 <= severity <= self._PARAMETER_MAX):
raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
self.severity = severity
self.mixture_width = mixture_width
self.chain_depth = chain_depth
self.alpha = alpha
self.all_ops = all_ops
def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
# Must be in a separate method so that we can override it in tests.
return torch._sample_dirichlet(params)
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, orig_image = self._extract_image(sample)
num_channels, height, width = get_image_dimensions(orig_image)
fill = self._parse_fill(orig_image, num_channels)
if isinstance(orig_image, torch.Tensor):
image = orig_image
else: # isinstance(input, PIL.Image.Image):
image = pil_to_tensor(orig_image)
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
orig_dims = list(image.shape)
batch = image.view([1] * max(4 - image.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
# Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
# with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
m = self._sample_dirichlet(
torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
)
# Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
combined_weights = self._sample_dirichlet(
torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
) * m[:, 1].view([batch_dims[0], -1])
mix = m[:, 0].view(batch_dims) * batch
for i in range(self.mixture_width):
aug = batch
depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
for _ in range(depth):
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
magnitudes = magnitudes_fn(self._PARAMETER_MAX, (height, width))
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
aug = self._apply_image_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=fill
)
mix.add_(combined_weights[:, i].view(batch_dims) * aug)
mix = mix.view(orig_dims).to(dtype=image.dtype)
if isinstance(orig_image, features.Image):
mix = features.Image.new_like(orig_image, mix)
elif isinstance(orig_image, PIL.Image.Image):
mix = to_pil_image(mix)
return _put_into_sample(sample, id, mix)
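
Editor's note: the Dirichlet bookkeeping in forward() is easy to misread, so here is a minimal standalone sketch of the weight scheme (batch size and alpha are invented; torch._sample_dirichlet is the same private op the transform uses). The invariant: the original image's weight plus the weights of all mixture_width augmented chains sum to 1 for every batch element:

import torch

alpha, mixture_width, batch_size = 1.0, 3, 2

# Beta(alpha, alpha) via a 2-parameter Dirichlet: column 0 weighs the original
# image, column 1 the combined augmented chains.
m = torch._sample_dirichlet(torch.tensor([alpha, alpha]).expand(batch_size, -1))

# Per-chain weights, rescaled by the augmented share m[:, 1].
combined_weights = torch._sample_dirichlet(
    torch.tensor([alpha] * mixture_width).expand(batch_size, -1)
) * m[:, 1].view(batch_size, -1)

total = m[:, 0] + combined_weights.sum(dim=1)
assert torch.allclose(total, torch.ones(batch_size))  # convex combination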
@@ -9,14 +9,16 @@ from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil
def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:
def fn(
id: Tuple[Any, ...], input: Any
) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
return input
return id, input
return None
try:
return next(query_recursively(fn, sample))
return next(query_recursively(fn, sample))[1]
except StopIteration:
raise TypeError("No image was found in the sample")
@@ -312,15 +312,18 @@ def apply_recursively(fn: Callable, obj: Any) -> Any:
return fn(obj)
def query_recursively(fn: Callable[[Any], Optional[D]], obj: Any) -> Iterator[D]:
def query_recursively(
fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = ()
) -> Iterator[D]:
# We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
# "a" == "a"[0][0]...
if (isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str)) or isinstance(
obj, collections.abc.Mapping
):
for item in obj.values() if isinstance(obj, collections.abc.Mapping) else obj:
yield from query_recursively(fn, item)
if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
for idx, item in enumerate(obj):
yield from query_recursively(fn, item, id=(*id, idx))
elif isinstance(obj, collections.abc.Mapping):
for key, item in obj.items():
yield from query_recursively(fn, item, id=(*id, key))
else:
result = fn(obj)
result = fn(id, obj)
if result is not None:
yield result
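
Editor's note: a small sketch of the new id-aware traversal with invented values. Each match now comes back with the key path that _put_into_sample needs to write a result back into the sample:

def find_ints(id, input):
    return (id, input) if isinstance(input, int) else None

sample = {"a": [10, "x"], "b": 20}
assert list(query_recursively(find_ints, sample)) == [(("a", 0), 10), (("b",), 20)]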