Unverified commit 48a61df2, authored by Vasilis Vryniotis, committed by GitHub
Browse files

Adding AugMix implementation (#5411)



* Adding basic augmix implementation.

* Finish the implementation.

* Add tests and documentation.

* Fix tests.

* Simplify code.

* Speed optimizations.

* Per image weights instead of per batch.

* Fix tests.

* Update torchvision/transforms/autoaugment.py
Co-authored-by: vfdev <vfdev.5@gmail.com>

* Change the default severity value so that the default strength matches RandAugment.
Co-authored-by: vfdev <vfdev.5@gmail.com>
parent e13206d9
...@@ -198,6 +198,7 @@ The new transform can be used standalone or mixed-and-matched with existing tran ...@@ -198,6 +198,7 @@ The new transform can be used standalone or mixed-and-matched with existing tran
AutoAugment AutoAugment
RandAugment RandAugment
TrivialAugmentWide TrivialAugmentWide
AugMix
.. _functional_transforms: .. _functional_transforms:
......
...@@ -263,6 +263,14 @@ augmenter = T.TrivialAugmentWide() ...@@ -263,6 +263,14 @@ augmenter = T.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)] imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs) plot(imgs)
####################################
# AugMix
# ~~~~~~
# The :class:`~torchvision.transforms.AugMix` transform automatically augments the data
# by mixing the original image with several randomly augmented copies of it.
# Each call below draws a fresh random mixture, so the four outputs differ.
augmenter = T.AugMix()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)
#################################### ####################################
# Randomly-applied transforms # Randomly-applied transforms
# --------------------------- # ---------------------------
......
...@@ -22,6 +22,8 @@ class ClassificationPresetTrain: ...@@ -22,6 +22,8 @@ class ClassificationPresetTrain:
trans.append(autoaugment.RandAugment(interpolation=interpolation)) trans.append(autoaugment.RandAugment(interpolation=interpolation))
elif auto_augment_policy == "ta_wide": elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
elif auto_augment_policy == "augmix":
trans.append(autoaugment.AugMix(interpolation=interpolation))
else: else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
......
...@@ -1601,6 +1601,25 @@ def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale): ...@@ -1601,6 +1601,25 @@ def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale):
transform.__repr__() transform.__repr__()
@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize("severity", [1, 10])
@pytest.mark.parametrize("mixture_width", [1, 2])
@pytest.mark.parametrize("chain_depth", [-1, 2])
@pytest.mark.parametrize("all_ops", [True, False])
@pytest.mark.parametrize("grayscale", [True, False])
def test_augmix(fill, severity, mixture_width, chain_depth, all_ops, grayscale):
    """Smoke-test AugMix on a PIL image across its main hyper-parameters.

    The transform is applied 100 times in a row (each output feeds the next
    call) so that many random op/magnitude combinations get exercised, and
    ``__repr__`` is called to make sure it does not raise.
    """
    # Local import keeps this block self-contained regardless of the module's
    # top-level imports.
    import torch

    random.seed(42)
    # AugMix samples its ops, magnitudes and mixing weights from torch's
    # global RNG (torch.randint / Dirichlet sampling), so seeding Python's
    # `random` alone does not make the test reproducible.
    torch.manual_seed(42)

    img = Image.open(GRACE_HOPPER)
    if grayscale:
        img, fill = _get_grayscale_test_image(img, fill)
    transform = transforms.AugMix(
        fill=fill, severity=severity, mixture_width=mixture_width, chain_depth=chain_depth, all_ops=all_ops
    )
    for _ in range(100):
        img = transform(img)
    transform.__repr__()
def test_random_crop(): def test_random_crop():
height = random.randint(10, 32) * 2 height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2
......
...@@ -720,7 +720,38 @@ def test_trivialaugmentwide(device, fill): ...@@ -720,7 +720,38 @@ def test_trivialaugmentwide(device, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide]) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"fill",
[
None,
85,
(10, -10, 10),
0.7,
[0.0, 0.0, 0.0],
[
1,
],
1,
],
)
def test_augmix(device, fill):
    """Check that eager and TorchScript AugMix agree on single images and on batches."""

    class DeterministicAugMix(T.AugMix):
        def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
            # Replace the random Dirichlet draw with a fixed function of the
            # parameters, so the comparison does not depend on the order in
            # which the eager and scripted paths consume random numbers.
            return params.softmax(dim=-1)

    image_shape = (3, 44, 56)
    tensor = torch.randint(0, 256, size=image_shape, dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4,) + image_shape, dtype=torch.uint8, device=device)

    eager = DeterministicAugMix(fill=fill)
    scripted = torch.jit.script(eager)

    # Repeat to cover a variety of randomly chosen ops and magnitudes.
    for _ in range(25):
        _test_transform_vs_scripted(eager, scripted, tensor)
        _test_transform_vs_scripted_on_batch(eager, scripted, batch_tensors)
@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide, T.AugMix])
def test_autoaugment_save(augmentation, tmpdir): def test_autoaugment_save(augmentation, tmpdir):
transform = augmentation() transform = augmentation()
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
......
...@@ -7,7 +7,7 @@ from torch import Tensor ...@@ -7,7 +7,7 @@ from torch import Tensor
from . import functional as F, InterpolationMode from . import functional as F, InterpolationMode
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"] __all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]
def _apply_op( def _apply_op(
...@@ -458,3 +458,154 @@ class TrivialAugmentWide(torch.nn.Module): ...@@ -458,3 +458,154 @@ class TrivialAugmentWide(torch.nn.Module):
f")" f")"
) )
return s return s
class AugMix(torch.nn.Module):
    r"""AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int): The severity of base augmentation operators. Must be in ``[1, 10]``.
            Default is ``3``.
        mixture_width (int): The number of augmentation chains. Default is ``3``.
        chain_depth (int): The depth of augmentation chains. A non-positive value denotes stochastic depth
            sampled from the interval [1, 3]. Default is ``-1``.
        alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

    Raises:
        ValueError: If ``severity`` is outside ``[1, 10]``.
    """

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        # Number of magnitude bins per op; also the upper bound for `severity`.
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the table of candidate ops.

        Args:
            num_bins (int): Number of discrete magnitude levels per op.
            image_size (List[int]): Image size; element 0 scales ``TranslateX`` and
                element 1 scales ``TranslateY``.

        Returns:
            Dict mapping op name to ``(magnitudes, signed)``. ``magnitudes`` is a 1-D tensor of
            ``num_bins`` levels, or a 0-d tensor for ops that take no magnitude. ``signed`` tells
            whether the sampled magnitude may be negated.
        """
        s = {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }
        if self.all_ops:
            s.update(
                {
                    "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Color": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
                }
            )
        return s

    @torch.jit.unused
    def _pil_to_tensor(self, img) -> Tensor:
        # Kept out of TorchScript compilation via @torch.jit.unused (PIL path only).
        return F.pil_to_tensor(img)

    @torch.jit.unused
    def _tensor_to_pil(self, img: Tensor):
        # Kept out of TorchScript compilation via @torch.jit.unused (PIL path only).
        return F.to_pil_image(img)

    def _sample_dirichlet(self, params: Tensor) -> Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, orig_img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        if isinstance(orig_img, Tensor):
            img = orig_img
            # Normalise `fill` to a list of floats for the tensor code path.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]
        else:
            # PIL input is processed as a tensor and converted back at the end.
            img = self._pil_to_tensor(orig_img)

        op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img))
        # The op names never change inside the loops below, so build the list once.
        op_names = list(op_meta.keys())

        # View the input as at least 4-D so that dim 0 acts as the batch dimension.
        orig_dims = list(img.shape)
        batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
        # Shape used to broadcast one weight per batch element over the remaining dims.
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
        # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].view([batch_dims[0], -1])

        # Start from the weighted original and accumulate each augmented chain on top of it.
        mix = m[:, 0].view(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            # Non-positive chain_depth means: draw the depth uniformly from {1, 2, 3}.
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                op_index = int(torch.randint(len(op_meta), (1,)).item())
                op_name = op_names[op_index]
                magnitudes, signed = op_meta[op_name]
                # Only the first `severity` bins are eligible; 0-d entries mark ops without a magnitude.
                magnitude = (
                    float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item())
                    if magnitudes.ndim > 0
                    else 0.0
                )
                if signed and torch.randint(2, (1,)):
                    magnitude *= -1.0
                aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
            mix.add_(combined_weights[:, i].view(batch_dims) * aug)
        # Restore the caller's shape and dtype.
        mix = mix.view(orig_dims).to(dtype=img.dtype)

        if not isinstance(orig_img, Tensor):
            return self._tensor_to_pil(mix)
        return mix

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"severity={self.severity}"
            f", mixture_width={self.mixture_width}"
            f", chain_depth={self.chain_depth}"
            f", alpha={self.alpha}"
            f", all_ops={self.all_ops}"
            f", interpolation={self.interpolation}"
            f", fill={self.fill}"
            f")"
        )
        return s
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment