"vscode:/vscode.git/clone" did not exist on "dba8c82571fd6a36075b343cbfefd5ff5c04f0cc"
Unverified Commit d8025b9a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

cleanup prototype auto augment transforms (#6463)

* cleanup prototype auto augment transforms

* remove custom fill parsing from auto augment
parent 38cd6b39
import math
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
import numbers
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
from ._utils import get_chw, is_simple_tensor
from ._utils import is_simple_tensor, query_chw
K = TypeVar("K")
V = TypeVar("V")
def _put_into_sample(sample: Any, id: int, item: Any) -> Any:
sample_flat, spec = tree_flatten(sample)
sample_flat[id] = item
return tree_unflatten(sample_flat, spec)
class _AutoAugmentBase(Transform):
def __init__(
self,
......@@ -31,6 +25,9 @@ class _AutoAugmentBase(Transform):
) -> None:
super().__init__()
self.interpolation = interpolation
if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
self.fill = fill
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
......@@ -38,41 +35,9 @@ class _AutoAugmentBase(Transform):
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _extract_image(
self,
sample: Any,
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
sample_flat, _ = tree_flatten(sample)
images = []
for id, inpt in enumerate(sample_flat):
if isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt):
images.append((id, inpt))
elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
if not images:
raise TypeError("Found no image in the sample.")
if len(images) > 1:
raise TypeError(
f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
)
return images[0]
def _parse_fill(
self, image: Union[PIL.Image.Image, torch.Tensor, features.Image], num_channels: int
) -> Union[int, float, Sequence[int], Sequence[float]]:
fill = self.fill
if isinstance(image, PIL.Image.Image) or fill is None:
return fill
if isinstance(fill, (int, float)):
fill = [float(fill)] * num_channels
else:
fill = [float(f) for f in fill]
return fill
def _get_params(self, sample: Any) -> Dict[str, Any]:
_, height, width = query_chw(sample)
return dict(height=height, width=width)
def _apply_image_transform(
self,
......@@ -277,22 +242,22 @@ class AutoAugment(_AutoAugmentBase):
else:
raise ValueError(f"The provided policy {policy} is not recognized.")
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
def _get_params(self, sample: Any) -> Dict[str, Any]:
params = super(AutoAugment, self)._get_params(sample)
params["policy"] = self._policies[int(torch.randint(len(self._policies), ()))]
return params
id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)
fill = self._parse_fill(image, num_channels)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt
policy = self._policies[int(torch.randint(len(self._policies), ()))]
for transform_id, probability, magnitude_idx in policy:
for transform_id, probability, magnitude_idx in params["policy"]:
if not torch.rand(()) <= probability:
continue
magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
magnitudes = magnitudes_fn(10, height, width)
magnitudes = magnitudes_fn(10, params["height"], params["width"])
if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5:
......@@ -300,11 +265,11 @@ class AutoAugment(_AutoAugmentBase):
else:
magnitude = 0.0
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=fill
inpt = self._apply_image_transform(
inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return _put_into_sample(sample, id, image)
return inpt
class RandAugment(_AutoAugmentBase):
......@@ -350,17 +315,14 @@ class RandAugment(_AutoAugmentBase):
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)
fill = self._parse_fill(image, num_channels)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
......@@ -368,11 +330,11 @@ class RandAugment(_AutoAugmentBase):
else:
magnitude = 0.0
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=fill
inpt = self._apply_image_transform(
inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return _put_into_sample(sample, id, image)
return inpt
class TrivialAugmentWide(_AutoAugmentBase):
......@@ -408,16 +370,13 @@ class TrivialAugmentWide(_AutoAugmentBase):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)
fill = self._parse_fill(image, num_channels)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
......@@ -425,8 +384,9 @@ class TrivialAugmentWide(_AutoAugmentBase):
else:
magnitude = 0.0
image = self._apply_image_transform(image, transform_id, magnitude, interpolation=self.interpolation, fill=fill)
return _put_into_sample(sample, id, image)
return self._apply_image_transform(
inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
class AugMix(_AutoAugmentBase):
......@@ -478,16 +438,13 @@ class AugMix(_AutoAugmentBase):
# Must be on a separate method so that we can overwrite it in tests.
return torch._sample_dirichlet(params)
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, orig_image = self._extract_image(sample)
num_channels, height, width = get_chw(orig_image)
fill = self._parse_fill(orig_image, num_channels)
if isinstance(orig_image, torch.Tensor):
image = orig_image
else: # isinstance(inpt, PIL.Image.Image):
image = pil_to_tensor(orig_image)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
image = inpt
elif isinstance(inpt, PIL.Image.Image):
image = pil_to_tensor(inpt)
else:
return inpt
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
......@@ -513,7 +470,7 @@ class AugMix(_AutoAugmentBase):
for _ in range(depth):
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
magnitudes = magnitudes_fn(self._PARAMETER_MAX, params["height"], params["width"])
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
if signed and torch.rand(()) <= 0.5:
......@@ -522,14 +479,14 @@ class AugMix(_AutoAugmentBase):
magnitude = 0.0
aug = self._apply_image_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=fill
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
mix.add_(combined_weights[:, i].view(batch_dims) * aug)
mix = mix.view(orig_dims).to(dtype=image.dtype)
if isinstance(orig_image, features.Image):
mix = features.Image.new_like(orig_image, mix)
elif isinstance(orig_image, PIL.Image.Image):
if isinstance(inpt, features.Image):
mix = features.Image.new_like(inpt, mix)
elif isinstance(inpt, PIL.Image.Image):
mix = to_pil_image(mix)
return _put_into_sample(sample, id, mix)
return mix
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment