Unverified Commit c8e3b2a5 authored by Vasilis Vryniotis, committed by GitHub

Adding Mixup and Cutmix (#4379)

* Add RandomMixupCutmix.

* Add test with real data.

* Use dataloader and collate in the test.

* Making RandomMixupCutmix JIT scriptable.

* Move out label_smoothing and try roll instead of flip

* Adding mixup/cutmix in references script.

* Handle one-hot encoded target in accuracy.

* Add support for devices in tests.

* Separate Mixup from Cutmix.

* Add check for floats.

* Adding device on expected value.

* Remove hardcoded weights.

* One-hot only when necessary.

* Fix linter.

* Moving mixup and cutmix to references.

* Final code clean up.
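
With these changes, the augmentations can be switched on from the references classification training script via the new flags, e.g. python train.py --model resnet50 --mixup-alpha 0.2 --cutmix-alpha 1.0 (the alpha values here are illustrative, not recommendations from this commit).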
parent b0962719
...@@ -4,11 +4,13 @@ import time

import torch
import torch.utils.data
from torch.utils.data.dataloader import default_collate
from torch import nn
import torchvision
from torchvision.transforms.functional import InterpolationMode

import presets
import transforms
import utils

try:
...@@ -164,10 +166,21 @@ def main(args):
    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
        collate_fn=collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)
...@@ -272,6 +285,8 @@ def get_args_parser(add_help=True):
    parser.add_argument('--label-smoothing', default=0.0, type=float,
                        help='label smoothing (default: 0.0)',
                        dest='label_smoothing')
    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
    parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
...
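
Since mixup and cutmix act on whole batches rather than on individual samples, the training script applies them inside the DataLoader's collate_fn, after default_collate has stacked the samples. A minimal standalone sketch of the same pattern; the helper name make_batch_transform_collate is illustrative and not part of the diff:

from torch.utils.data.dataloader import default_collate

def make_batch_transform_collate(batch_transform):
    # Stack the individual (image, label) samples into batch tensors first,
    # then hand the whole batch to the batch-level transform.
    def collate_fn(batch):
        images, targets = default_collate(batch)
        return batch_transform(images, targets)
    return collate_fn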
import math
from typing import Tuple

import torch
from torch import Tensor
from torchvision.transforms import functional as F


class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int,
                 p: float = 0.5, alpha: float = 1.0,
                 inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for num_classes."
        assert alpha > 0, "Alpha param must be positive."

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif not batch.is_floating_point():
            raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
        elif target.dtype != torch.int64:
            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as in the mixup paper, page 3: lambda ~ Beta(alpha, alpha),
        # drawn here via a two-element Dirichlet with concentration [alpha, alpha].
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)


class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int,
                 p: float = 0.5, alpha: float = 1.0,
                 inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for num_classes."
        assert alpha > 0, "Alpha param must be positive."

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif not batch.is_floating_point():
            raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
        elif target.dtype != torch.int64:
            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as in the cutmix paper, page 12 (with minor corrections on typos):
        # lambda ~ Beta(alpha, alpha), drawn via a two-element Dirichlet.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        W, H = F.get_image_size(batch)

        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        # Recompute lambda from the clipped box area so the target weights match the pixels actually mixed.
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)
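
A quick usage sketch for the two transforms above (shapes and hyperparameter values are illustrative, assuming this transforms module is importable):

import torch

mixup = RandomMixup(num_classes=10, p=1.0, alpha=0.2)
cutmix = RandomCutmix(num_classes=10, p=1.0, alpha=1.0)

images = torch.rand(8, 3, 224, 224)          # (B, C, H, W), float
labels = torch.randint(0, 10, (8,))          # (B,), torch.int64

mixed, soft_targets = mixup(images, labels)  # soft_targets has shape (8, 10)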
...@@ -189,6 +189,8 @@ def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        if target.ndim == 2:
            target = target.max(dim=1)[1]

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
...
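
The two lines added to accuracy make it tolerant of the soft (B, num_classes) targets that mixup/cutmix produce: argmax collapses them back to hard labels. A small illustration with made-up values:

import torch

soft = torch.tensor([[0.7, 0.3, 0.0],
                     [0.1, 0.0, 0.9]])  # (B, num_classes) soft targets
hard = soft.max(dim=1)[1]               # tensor([0, 2]); the dominant class wins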
...@@ -1311,7 +1311,8 @@ def test_random_choice():
            transforms.Resize(15),
            transforms.Resize(20),
            transforms.CenterCrop(10)
        ],
        [1 / 3, 1 / 3, 1 / 3]
    )
    img = transforms.ToPILImage()(torch.rand(3, 25, 25))
    num_samples = 250
...
...@@ -515,9 +515,20 @@ class RandomOrder(RandomTransforms):

class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list. This transform does not support torchscript.
    """
    def __init__(self, transforms, p=None):
        super().__init__(transforms)
        if p is not None and not isinstance(p, Sequence):
            raise TypeError("Argument p should be a sequence")
        self.p = p

    def __call__(self, *args):
        t = random.choices(self.transforms, weights=self.p)[0]
        return t(*args)

    def __repr__(self):
        format_string = super().__repr__()
        format_string += '(p={0})'.format(self.p)
        return format_string


class RandomCrop(torch.nn.Module):
...
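
The updated RandomChoice accepts optional selection weights and forwards multiple arguments, which is what lets it arbitrate between the batch-level mixup and cutmix transforms. A sketch reusing the objects from the earlier example; the 2:1 weighting is illustrative:

from torchvision.transforms import RandomChoice

mixupcutmix = RandomChoice([mixup, cutmix], p=[2 / 3, 1 / 3])
mixed, soft_targets = mixupcutmix(images, labels)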