Unverified Commit 3a7e5e38 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Refactor AutoAugment to support more augmentations. (#4338)

parent 72393f2b
......@@ -10,6 +10,45 @@ from . import functional as F, InterpolationMode
__all__ = ["AutoAugmentPolicy", "AutoAugment"]
def _apply_op(img: Tensor, op_name: str, magnitude: float,
interpolation: InterpolationMode, fill: Optional[List[float]]):
if op_name == "ShearX":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
interpolation=interpolation, fill=fill)
elif op_name == "ShearY":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
interpolation=interpolation, fill=fill)
elif op_name == "TranslateX":
img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0,
interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "TranslateY":
img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0,
interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "Rotate":
img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
elif op_name == "Brightness":
img = F.adjust_brightness(img, 1.0 + magnitude)
elif op_name == "Color":
img = F.adjust_saturation(img, 1.0 + magnitude)
elif op_name == "Contrast":
img = F.adjust_contrast(img, 1.0 + magnitude)
elif op_name == "Sharpness":
img = F.adjust_sharpness(img, 1.0 + magnitude)
elif op_name == "Posterize":
img = F.posterize(img, int(magnitude))
elif op_name == "Solarize":
img = F.solarize(img, magnitude)
elif op_name == "AutoContrast":
img = F.autocontrast(img)
elif op_name == "Equalize":
img = F.equalize(img)
elif op_name == "Invert":
img = F.invert(img)
else:
raise ValueError("The provided operator {} is not recognized.".format(op_name))
return img
class AutoAugmentPolicy(Enum):
"""AutoAugment policies learned on different datasets.
Available policies are IMAGENET, CIFAR10 and SVHN.
......@@ -19,9 +58,39 @@ class AutoAugmentPolicy(Enum):
SVHN = "svhn"
def _get_transforms( # type: ignore[return]
class AutoAugment(torch.nn.Module):
r"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
policy (AutoAugmentPolicy): Desired policy enum defined by
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
def __init__(
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None
) -> None:
super().__init__()
self.policy = policy
self.interpolation = interpolation
self.fill = fill
self.transforms = self._get_transforms(policy)
def _get_transforms(
self,
policy: AutoAugmentPolicy
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
if policy == AutoAugmentPolicy.IMAGENET:
return [
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
......@@ -106,62 +175,28 @@ def _get_transforms( # type: ignore[return]
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
]
else:
raise ValueError("The provided policy {} is not recognized.".format(policy))
def _get_magnitudes() -> Dict[str, Tuple[Optional[Tensor], Optional[bool]]]:
_BINS = 10
def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
return {
# name: (magnitudes, signed)
"ShearX": (torch.linspace(0.0, 0.3, _BINS), True),
"ShearY": (torch.linspace(0.0, 0.3, _BINS), True),
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True),
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True),
"Rotate": (torch.linspace(0.0, 30.0, _BINS), True),
"Brightness": (torch.linspace(0.0, 0.9, _BINS), True),
"Color": (torch.linspace(0.0, 0.9, _BINS), True),
"Contrast": (torch.linspace(0.0, 0.9, _BINS), True),
"Sharpness": (torch.linspace(0.0, 0.9, _BINS), True),
"Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False),
"Solarize": (torch.linspace(256.0, 0.0, _BINS), False),
"AutoContrast": (None, None),
"Equalize": (None, None),
"Invert": (None, None),
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
"Color": (torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
"AutoContrast": (torch.tensor(0.0), False),
"Equalize": (torch.tensor(0.0), False),
"Invert": (torch.tensor(0.0), False),
}
class AutoAugment(torch.nn.Module):
r"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
policy (AutoAugmentPolicy): Desired policy enum defined by
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
def __init__(
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None
) -> None:
super().__init__()
self.policy = policy
self.interpolation = interpolation
self.fill = fill
self.transforms = _get_transforms(policy)
if self.transforms is None:
raise ValueError("The provided policy {} is not recognized.".format(policy))
self._op_meta = _get_magnitudes()
@staticmethod
def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
"""Get parameters for autoaugment transformation
......@@ -175,9 +210,6 @@ class AutoAugment(torch.nn.Module):
return policy_id, probs, signs
def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]:
return self._op_meta[name]
def forward(self, img: Tensor) -> Tensor:
"""
img (PIL Image or Tensor): Image to be transformed.
......@@ -196,46 +228,12 @@ class AutoAugment(torch.nn.Module):
for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]):
if probs[i] <= p:
magnitudes, signed = self._get_op_meta(op_name)
magnitude = float(magnitudes[magnitude_id].item()) \
if magnitudes is not None and magnitude_id is not None else 0.0
if signed is not None and signed and signs[i] == 0:
op_meta = self._get_magnitudes(10, F.get_image_size(img))
magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
if signed and signs[i] == 0:
magnitude *= -1.0
if op_name == "ShearX":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
interpolation=self.interpolation, fill=fill)
elif op_name == "ShearY":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
interpolation=self.interpolation, fill=fill)
elif op_name == "TranslateX":
img = F.affine(img, angle=0.0, translate=[int(F.get_image_size(img)[0] * magnitude), 0], scale=1.0,
interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "TranslateY":
img = F.affine(img, angle=0.0, translate=[0, int(F.get_image_size(img)[1] * magnitude)], scale=1.0,
interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "Rotate":
img = F.rotate(img, magnitude, interpolation=self.interpolation, fill=fill)
elif op_name == "Brightness":
img = F.adjust_brightness(img, 1.0 + magnitude)
elif op_name == "Color":
img = F.adjust_saturation(img, 1.0 + magnitude)
elif op_name == "Contrast":
img = F.adjust_contrast(img, 1.0 + magnitude)
elif op_name == "Sharpness":
img = F.adjust_sharpness(img, 1.0 + magnitude)
elif op_name == "Posterize":
img = F.posterize(img, int(magnitude))
elif op_name == "Solarize":
img = F.solarize(img, magnitude)
elif op_name == "AutoContrast":
img = F.autocontrast(img)
elif op_name == "Equalize":
img = F.equalize(img)
elif op_name == "Invert":
img = F.invert(img)
else:
raise ValueError("The provided operator {} is not recognized.".format(op_name))
img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
return img
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment