"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b45204ea5aa0160d343c79bfb19ec9ceda637a5b"
Unverified commit 446b2ca5, authored by SamuelGabriel and committed by GitHub

Integration of TrivialAugment with the current AutoAugment Code (#4221)



* Initial Proposal

* Tensor Save Test + Test Name Fix

* Formatting + removing unused argument

* Fix old argument

* Fix isnan check error + indexing error with JIT

* Fix Flake8 error.

* Fix MyPy error.

* Fix Flake8 error.

* Fix PyTorch JIT error in UnitTests due to type annotation.

* Fixing tests.

* Removing type ignore.

* Adding support for ta_wide in references.

* Move methods in classes.

* Moving new classes to the bottom.

* Specialize TA to TAwide

* Add missing type

* Fixing lint

* Fix doc

* Fix search space of TrivialAugment.
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <vvryniotis@fb.com>
parent 80d5f50d
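In short, this change adds a `TrivialAugmentWide` transform alongside `AutoAugment` and `RandAugment`. A minimal usage sketch, not part of the diff below (the crop size is an arbitrary placeholder):

```python
import torchvision.transforms as T

# TrivialAugmentWide accepts PIL images or uint8 tensors of shape [..., 1 or 3, H, W].
train_transform = T.Compose([
    T.RandomResizedCrop(224),
    T.TrivialAugmentWide(),   # applies one randomly chosen op with a random magnitude
    T.ToTensor(),
])
```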
@@ -234,6 +234,11 @@ The new transform can be used standalone or mixed-and-matched with existing transforms
.. autoclass:: RandAugment
    :members:

`TrivialAugmentWide <https://arxiv.org/abs/2103.10158>`_ is a dataset-independent data-augmentation technique which improves the accuracy of Image Classification models.

.. autoclass:: TrivialAugmentWide
    :members:

.. _functional_transforms:

Functional Transforms
...
@@ -253,6 +253,14 @@ augmenter = T.RandAugment()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

####################################
# TrivialAugmentWide
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data.
augmenter = T.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

####################################
# Randomly-applied transforms
# ---------------------------
...
@@ -11,6 +11,8 @@ class ClassificationPresetTrain:
        if auto_augment_policy is not None:
            if auto_augment_policy == "ra":
                trans.append(autoaugment.RandAugment())
            elif auto_augment_policy == "ta_wide":
                trans.append(autoaugment.TrivialAugmentWide())
            else:
                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
                trans.append(autoaugment.AutoAugment(policy=aa_policy))
...
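The preset above maps the policy string `"ta_wide"` to the new transform. A hedged sketch of the same dispatch in isolation (the `build_auto_augment` helper is hypothetical, not part of the codebase):

```python
import torch
from torchvision.transforms import autoaugment


def build_auto_augment(policy: str) -> torch.nn.Module:
    # Mirrors the branching in ClassificationPresetTrain above.
    if policy == "ra":
        return autoaugment.RandAugment()
    if policy == "ta_wide":
        return autoaugment.TrivialAugmentWide()
    return autoaugment.AutoAugment(policy=autoaugment.AutoAugmentPolicy(policy))


transform = build_auto_augment("ta_wide")
```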
@@ -1502,6 +1502,17 @@ def test_randaugment(num_ops, magnitude, fill):
    transform.__repr__()

@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
def test_trivialaugmentwide(fill, num_magnitude_bins):
    random.seed(42)
    img = Image.open(GRACE_HOPPER)
    transform = transforms.TrivialAugmentWide(fill=fill, num_magnitude_bins=num_magnitude_bins)
    for _ in range(100):
        img = transform(img)
    transform.__repr__()

def test_random_crop():
    height = random.randint(10, 32) * 2
    width = random.randint(10, 32) * 2
...
@@ -547,7 +547,20 @@ def test_randaugment(device, num_ops, magnitude, fill):
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
def test_trivialaugmentwide(device, fill):
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    transform = T.TrivialAugmentWide(fill=fill)
    s_transform = torch.jit.script(transform)
    for _ in range(25):
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
def test_autoaugment_save(augmentation, tmpdir):
    transform = augmentation()
    s_transform = torch.jit.script(transform)
...
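The new test above checks that the transform produces the same results eagerly and under TorchScript. A minimal sketch of what that guarantees in practice (image size chosen arbitrarily):

```python
import torch
import torchvision.transforms as T

transform = T.TrivialAugmentWide()
scripted = torch.jit.script(transform)  # scriptable because forward() is fully type-annotated

img = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
out = scripted(img)
assert out.shape == img.shape
```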
@@ -7,7 +7,7 @@ from typing import List, Tuple, Optional, Dict
from . import functional as F, InterpolationMode

__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]


def _apply_op(img: Tensor, op_name: str, magnitude: float,
@@ -44,6 +44,8 @@ def _apply_op(img: Tensor, op_name: str, magnitude: float,
        img = F.equalize(img)
    elif op_name == "Invert":
        img = F.invert(img)
    elif op_name == "Identity":
        pass
    else:
        raise ValueError("The provided operator {} is not recognized.".format(op_name))
    return img
@@ -325,3 +327,79 @@ class RandAugment(torch.nn.Module):
        s += ', fill={fill}'
        s += ')'
        return s.format(**self.__dict__)


class TrivialAugmentWide(torch.nn.Module):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST,
                 fill: Optional[List[float]] = None) -> None:
        super().__init__()
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        return {
            # op_name: (magnitudes, signed)
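            # Ranges below follow the wider search space of the TrivialAugment (Wide)
            # variant: shear/color/contrast/brightness/sharpness magnitudes go up to 0.99,
            # translations up to 32 pixels, rotations up to 135 degrees, Solarize sweeps
            # from 256 down to 0, and Posterize steps from 8 down to 2 bits across the bins.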
"Identity": (torch.tensor(0.0), False),
"ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
"TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
"Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
"Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
"Color": (torch.linspace(0.0, 0.99, num_bins), True),
"Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
"AutoContrast": (torch.tensor(0.0), False),
"Equalize": (torch.tensor(0.0), False),
}

    def forward(self, img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]

        op_meta = self._augmentation_space(self.num_magnitude_bins)
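        # Core of TrivialAugment: pick exactly one op uniformly at random, then one of
        # the num_magnitude_bins magnitude values uniformly, and negate signed
        # magnitudes with probability 0.5.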
        op_index = int(torch.randint(len(op_meta), (1,)).item())
        op_name = list(op_meta.keys())[op_index]
        magnitudes, signed = op_meta[op_name]
        magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \
            if magnitudes.ndim > 0 else 0.0
        if signed and torch.randint(2, (1,)):
            magnitude *= -1.0

        return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_magnitude_bins={num_magnitude_bins}'
        s += ', interpolation={interpolation}'
        s += ', fill={fill}'
        s += ')'
        return s.format(**self.__dict__)
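For reference, the default search space can be inspected directly. `_augmentation_space` is a private helper, so the snippet below is only an illustration of the defaults in this version, not a supported API:

```python
import torchvision.transforms as T

t = T.TrivialAugmentWide()  # num_magnitude_bins defaults to 30 here
space = t._augmentation_space(t.num_magnitude_bins)
for name, (magnitudes, signed) in space.items():
    print(f"{name:12s} bins={magnitudes.numel():2d} signed={signed}")
# 14 ops in total; each call applies exactly one of them.
```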