You need to sign in or sign up before continuing.
Unverified Commit 446b2ca5 authored by SamuelGabriel's avatar SamuelGabriel Committed by GitHub
Browse files

Integration of TrivialAugment with the current AutoAugment Code (#4221)



* Initial Proposal

* Tensor Save Test + Test Name Fix

* Formatting + removing unused argument

* fix old argument

* fix isnan check error + indexing error with jit

* Fix Flake8 error.

* Fix MyPy error.

* Fix Flake8 error.

* Fix PyTorch JIT error in UnitTests due to type annotation.

* Fixing tests.

* Removing type ignore.

* Adding support of ta_wide in references.

* Move methods in classes.

* Moving new classes to the bottom.

* Specialize to TA to TAwide

* Add missing type

* Fixing lint

* Fix doc

* Fix search space of TrivialAugment.
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: default avatarVasilis Vryniotis <vvryniotis@fb.com>
parent 80d5f50d
...@@ -234,6 +234,11 @@ The new transform can be used standalone or mixed-and-matched with existing tran ...@@ -234,6 +234,11 @@ The new transform can be used standalone or mixed-and-matched with existing tran
.. autoclass:: RandAugment .. autoclass:: RandAugment
:members: :members:
`TrivialAugmentWide <https://arxiv.org/abs/2103.10158>`_ is a dataset-independent data-augmentation technique which improves the accuracy of Image Classification models.
.. autoclass:: TrivialAugmentWide
:members:
.. _functional_transforms: .. _functional_transforms:
Functional Transforms Functional Transforms
......
...@@ -253,6 +253,14 @@ augmenter = T.RandAugment() ...@@ -253,6 +253,14 @@ augmenter = T.RandAugment()
imgs = [augmenter(orig_img) for _ in range(4)] imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs) plot(imgs)
####################################
# TrivialAugmentWide
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data.
augmenter = T.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)
#################################### ####################################
# Randomly-applied transforms # Randomly-applied transforms
# --------------------------- # ---------------------------
......
...@@ -11,6 +11,8 @@ class ClassificationPresetTrain: ...@@ -11,6 +11,8 @@ class ClassificationPresetTrain:
if auto_augment_policy is not None: if auto_augment_policy is not None:
if auto_augment_policy == "ra": if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment()) trans.append(autoaugment.RandAugment())
elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide())
else: else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy)) trans.append(autoaugment.AutoAugment(policy=aa_policy))
......
...@@ -1502,6 +1502,17 @@ def test_randaugment(num_ops, magnitude, fill): ...@@ -1502,6 +1502,17 @@ def test_randaugment(num_ops, magnitude, fill):
transform.__repr__() transform.__repr__()
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
def test_trivialaugmentwide(fill, num_magnitude_bins):
random.seed(42)
img = Image.open(GRACE_HOPPER)
transform = transforms.TrivialAugmentWide(fill=fill, num_magnitude_bins=num_magnitude_bins)
for _ in range(100):
img = transform(img)
transform.__repr__()
def test_random_crop(): def test_random_crop():
height = random.randint(10, 32) * 2 height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2
......
...@@ -547,7 +547,20 @@ def test_randaugment(device, num_ops, magnitude, fill): ...@@ -547,7 +547,20 @@ def test_randaugment(device, num_ops, magnitude, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment]) @pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
def test_trivialaugmentwide(device, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
transform = T.TrivialAugmentWide(fill=fill)
s_transform = torch.jit.script(transform)
for _ in range(25):
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
def test_autoaugment_save(augmentation, tmpdir): def test_autoaugment_save(augmentation, tmpdir):
transform = augmentation() transform = augmentation()
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
......
...@@ -7,7 +7,7 @@ from typing import List, Tuple, Optional, Dict ...@@ -7,7 +7,7 @@ from typing import List, Tuple, Optional, Dict
from . import functional as F, InterpolationMode from . import functional as F, InterpolationMode
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment"] __all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
def _apply_op(img: Tensor, op_name: str, magnitude: float, def _apply_op(img: Tensor, op_name: str, magnitude: float,
...@@ -44,6 +44,8 @@ def _apply_op(img: Tensor, op_name: str, magnitude: float, ...@@ -44,6 +44,8 @@ def _apply_op(img: Tensor, op_name: str, magnitude: float,
img = F.equalize(img) img = F.equalize(img)
elif op_name == "Invert": elif op_name == "Invert":
img = F.invert(img) img = F.invert(img)
elif op_name == "Identity":
pass
else: else:
raise ValueError("The provided operator {} is not recognized.".format(op_name)) raise ValueError("The provided operator {} is not recognized.".format(op_name))
return img return img
...@@ -325,3 +327,79 @@ class RandAugment(torch.nn.Module): ...@@ -325,3 +327,79 @@ class RandAugment(torch.nn.Module):
s += ', fill={fill}' s += ', fill={fill}'
s += ')' s += ')'
return s.format(**self.__dict__) return s.format(**self.__dict__)
class TrivialAugmentWide(torch.nn.Module):
r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
num_magnitude_bins (int): The number of different magnitude values.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None) -> None:
super().__init__()
self.num_magnitude_bins = num_magnitude_bins
self.interpolation = interpolation
self.fill = fill
def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
return {
# op_name: (magnitudes, signed)
"Identity": (torch.tensor(0.0), False),
"ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
"TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
"Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
"Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
"Color": (torch.linspace(0.0, 0.99, num_bins), True),
"Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
"AutoContrast": (torch.tensor(0.0), False),
"Equalize": (torch.tensor(0.0), False),
}
def forward(self, img: Tensor) -> Tensor:
"""
img (PIL Image or Tensor): Image to be transformed.
Returns:
PIL Image or Tensor: Transformed image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F.get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]
op_meta = self._augmentation_space(self.num_magnitude_bins)
op_index = int(torch.randint(len(op_meta), (1,)).item())
op_name = list(op_meta.keys())[op_index]
magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \
if magnitudes.ndim > 0 else 0.0
if signed and torch.randint(2, (1,)):
magnitude *= -1.0
return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'num_magnitude_bins={num_magnitude_bins}'
s += ', interpolation={interpolation}'
s += ', fill={fill}'
s += ')'
return s.format(**self.__dict__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment