Unverified commit 5a815541, authored by Vasilis Vryniotis and committed by GitHub
Browse files

Adding RandAugment implementation (#4348)

* Adding randaugment implementation

* Refactoring.

* Adding num_magnitude_bins.

* Adding FIXME.
parent f52ddb0c
...@@ -9,6 +9,9 @@ class ClassificationPresetTrain: ...@@ -9,6 +9,9 @@ class ClassificationPresetTrain:
if hflip_prob > 0: if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob)) trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None: if auto_augment_policy is not None:
if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment())
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy)) trans.append(autoaugment.AutoAugment(policy=aa_policy))
trans.extend([ trans.extend([
......
...@@ -1490,6 +1490,18 @@ def test_autoaugment(policy, fill): ...@@ -1490,6 +1490,18 @@ def test_autoaugment(policy, fill):
transform.__repr__() transform.__repr__()
@pytest.mark.parametrize('num_ops', [1, 2, 3])
@pytest.mark.parametrize('magnitude', [7, 9, 11])
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
def test_randaugment(num_ops, magnitude, fill):
    """Smoke-test RandAugment on a PIL image.

    Applies the transform 100 times in a row (feeding each output back in) for
    every combination of ``num_ops``, ``magnitude`` and ``fill``; the test passes
    when no call raises. ``__repr__`` is exercised as well.
    """
    # Local import so the block does not rely on a module-level ``import torch``.
    import torch

    random.seed(42)
    # RandAugment draws its op indices and signs with torch.randint, so the torch
    # RNG has to be seeded too; seeding only ``random`` leaves the run non-reproducible.
    torch.manual_seed(42)
    img = Image.open(GRACE_HOPPER)
    transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
    for _ in range(100):
        img = transform(img)
    transform.__repr__()
def test_random_crop(): def test_random_crop():
height = random.randint(10, 32) * 2 height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2
......
...@@ -525,7 +525,6 @@ def test_autoaugment(device, policy, fill): ...@@ -525,7 +525,6 @@ def test_autoaugment(device, policy, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
s_transform = None
transform = T.AutoAugment(policy=policy, fill=fill) transform = T.AutoAugment(policy=policy, fill=fill)
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
for _ in range(25): for _ in range(25):
...@@ -533,8 +532,24 @@ def test_autoaugment(device, policy, fill): ...@@ -533,8 +532,24 @@ def test_autoaugment(device, policy, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('num_ops', [1, 2, 3])
@pytest.mark.parametrize('magnitude', [7, 9, 11])
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
def test_randaugment(device, num_ops, magnitude, fill):
    """Check that eager and scripted RandAugment agree on tensors.

    A single uint8 image and a batch of four are generated on ``device``; the
    eager transform and its ``torch.jit.script`` version are then compared 25
    times on both inputs via the shared ``_test_transform_vs_scripted*`` helpers.
    """
    image_shape = (3, 44, 56)
    single_image = torch.randint(0, 256, size=image_shape, dtype=torch.uint8, device=device)
    image_batch = torch.randint(0, 256, size=(4,) + image_shape, dtype=torch.uint8, device=device)

    eager = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
    scripted = torch.jit.script(eager)

    repetitions = 25
    for _ in range(repetitions):
        _test_transform_vs_scripted(eager, scripted, single_image)
        _test_transform_vs_scripted_on_batch(eager, scripted, image_batch)
@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment])
def test_autoaugment_save(augmentation, tmpdir):
transform = augmentation()
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt")) s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
......
...@@ -7,7 +7,7 @@ from typing import List, Tuple, Optional, Dict ...@@ -7,7 +7,7 @@ from typing import List, Tuple, Optional, Dict
from . import functional as F, InterpolationMode from . import functional as F, InterpolationMode
__all__ = ["AutoAugmentPolicy", "AutoAugment"] __all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment"]
def _apply_op(img: Tensor, op_name: str, magnitude: float, def _apply_op(img: Tensor, op_name: str, magnitude: float,
...@@ -58,6 +58,7 @@ class AutoAugmentPolicy(Enum): ...@@ -58,6 +58,7 @@ class AutoAugmentPolicy(Enum):
SVHN = "svhn" SVHN = "svhn"
# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving them into a base class
class AutoAugment(torch.nn.Module): class AutoAugment(torch.nn.Module):
r"""AutoAugment data augmentation method based on r"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_. `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
...@@ -85,9 +86,9 @@ class AutoAugment(torch.nn.Module): ...@@ -85,9 +86,9 @@ class AutoAugment(torch.nn.Module):
self.policy = policy self.policy = policy
self.interpolation = interpolation self.interpolation = interpolation
self.fill = fill self.fill = fill
self.transforms = self._get_transforms(policy) self.policies = self._get_policies(policy)
def _get_transforms( def _get_policies(
self, self,
policy: AutoAugmentPolicy policy: AutoAugmentPolicy
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
...@@ -178,9 +179,9 @@ class AutoAugment(torch.nn.Module): ...@@ -178,9 +179,9 @@ class AutoAugment(torch.nn.Module):
else: else:
raise ValueError("The provided policy {} is not recognized.".format(policy)) raise ValueError("The provided policy {} is not recognized.".format(policy))
def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
return { return {
# name: (magnitudes, signed) # op_name: (magnitudes, signed)
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True), "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
...@@ -224,11 +225,11 @@ class AutoAugment(torch.nn.Module): ...@@ -224,11 +225,11 @@ class AutoAugment(torch.nn.Module):
elif fill is not None: elif fill is not None:
fill = [float(f) for f in fill] fill = [float(f) for f in fill]
transform_id, probs, signs = self.get_params(len(self.transforms)) transform_id, probs, signs = self.get_params(len(self.policies))
for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]): for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
if probs[i] <= p: if probs[i] <= p:
op_meta = self._get_magnitudes(10, F.get_image_size(img)) op_meta = self._augmentation_space(10, F.get_image_size(img))
magnitudes, signed = op_meta[op_name] magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0 magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
if signed and signs[i] == 0: if signed and signs[i] == 0:
...@@ -239,3 +240,87 @@ class AutoAugment(torch.nn.Module): ...@@ -239,3 +240,87 @@ class AutoAugment(torch.nn.Module):
def __repr__(self) -> str: def __repr__(self) -> str:
return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
class RandAugment(torch.nn.Module):
    r"""RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int): Number of augmentation transformations to apply sequentially.
        magnitude (int): Magnitude for all the transformations. It is used as an index into
            the ``num_magnitude_bins`` magnitude values of the selected operation, so it
            has to be a valid index for a sequence of that length.
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30,
                 interpolation: InterpolationMode = InterpolationMode.NEAREST,
                 fill: Optional[List[float]] = None) -> None:
        super().__init__()
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the table of available operations.

        Args:
            num_bins (int): Number of magnitude values generated per operation.
            image_size (List[int]): Image size; element 0 scales the "TranslateX"
                range and element 1 scales the "TranslateY" range.

        Returns:
            Dict mapping the op name to ``(magnitudes, signed)``. ``magnitudes`` is a
            1-D tensor with ``num_bins`` entries, or a 0-dim tensor for operations
            that take no magnitude. ``signed`` tells whether the magnitude may be
            negated at random.
        """
        return {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
            # 0-dim tensors mark operations without a magnitude (see ``forward``).
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
            "Invert": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            # Tensor inputs need ``fill`` as a list of floats, one per channel
            # when a single number was given.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]

        # The op table depends only on the number of bins and the image size, so it
        # is built once instead of once per applied op.
        # NOTE(review): assumes no operation applied by _apply_op changes the image
        # size between iterations - confirm.
        op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
        op_names = list(op_meta.keys())

        for _ in range(self.num_ops):
            op_index = int(torch.randint(len(op_names), (1,)).item())
            op_name = op_names[op_index]
            magnitudes, signed = op_meta[op_name]
            # 0-dim ``magnitudes`` means the op takes no magnitude.
            magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
            # Signed ops get their magnitude negated with probability 1/2.
            if signed and torch.randint(2, (1,)):
                magnitude *= -1.0
            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_ops={num_ops}'
        s += ', magnitude={magnitude}'
        s += ', num_magnitude_bins={num_magnitude_bins}'
        s += ', interpolation={interpolation}'
        s += ', fill={fill}'
        s += ')'
        # Pass only the fields that are printed rather than the whole module __dict__.
        return s.format(
            num_ops=self.num_ops,
            magnitude=self.magnitude,
            num_magnitude_bins=self.num_magnitude_bins,
            interpolation=self.interpolation,
            fill=self.fill,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment