Unverified Commit 3e4e353d authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Cutmix -> CutMix (#7784)

parent edde8255
......@@ -274,8 +274,8 @@ are combining pairs of images together. These can be used after the dataloader
:toctree: generated/
:template: class.rst
v2.Cutmix
v2.Mixup
v2.CutMix
v2.MixUp
.. _functional_transforms:
......
......@@ -13,15 +13,15 @@ def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
mixup_cutmix = []
if mixup_alpha > 0:
mixup_cutmix.append(
transforms_module.Mixup(alpha=mixup_alpha, num_categories=num_categories)
transforms_module.MixUp(alpha=mixup_alpha, num_categories=num_categories)
if use_v2
else RandomMixup(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
else RandomMixUp(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
)
if cutmix_alpha > 0:
mixup_cutmix.append(
transforms_module.Cutmix(alpha=mixup_alpha, num_categories=num_categories)
transforms_module.CutMix(alpha=mixup_alpha, num_categories=num_categories)
if use_v2
else RandomCutmix(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
else RandomCutMix(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
)
if not mixup_cutmix:
return None
......@@ -29,8 +29,8 @@ def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
return transforms_module.RandomChoice(mixup_cutmix)
class RandomMixup(torch.nn.Module):
"""Randomly apply Mixup to the provided batch and targets.
class RandomMixUp(torch.nn.Module):
"""Randomly apply MixUp to the provided batch and targets.
The class implements the data augmentations as described in the paper
`"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
......@@ -112,8 +112,8 @@ class RandomMixup(torch.nn.Module):
return s
class RandomCutmix(torch.nn.Module):
"""Randomly apply Cutmix to the provided batch and targets.
class RandomCutMix(torch.nn.Module):
"""Randomly apply CutMix to the provided batch and targets.
The class implements the data augmentations as described in the paper
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
<https://arxiv.org/abs/1905.04899>`_.
......
......@@ -60,8 +60,8 @@ def parametrize(transforms_with_inputs):
],
)
for transform in [
transforms.RandomMixup(alpha=1.0),
transforms.RandomCutmix(alpha=1.0),
transforms.RandomMixUp(alpha=1.0),
transforms.RandomCutMix(alpha=1.0),
]
]
)
......
......@@ -1914,7 +1914,7 @@ class TestCutMixMixUp:
def __len__(self):
return self.size
@pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
@pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
def test_supported_input_structure(self, T):
batch_size = 32
......@@ -1964,7 +1964,7 @@ class TestCutMixMixUp:
check_output(img, target)
@needs_cuda
@pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
@pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
def test_cpu_vs_gpu(self, T):
num_classes = 10
batch_size = 3
......@@ -1976,7 +1976,7 @@ class TestCutMixMixUp:
_check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None)
@pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
@pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
def test_error(self, T):
num_classes = 10
......
from ._presets import StereoMatching # usort: skip
from ._augment import RandomCutmix, RandomMixup, SimpleCopyPaste
from ._augment import RandomCutMix, RandomMixUp, SimpleCopyPaste
from ._geometry import FixedSizeCrop
from ._misc import PermuteDimensions, TransposeDimensions
from ._type_conversion import LabelToOneHot
......@@ -14,7 +14,7 @@ from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size
class _BaseMixupCutmix(_RandomApplyTransform):
class _BaseMixUpCutMix(_RandomApplyTransform):
def __init__(self, alpha: float, p: float = 0.5) -> None:
super().__init__(p=p)
self.alpha = alpha
......@@ -38,7 +38,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
return proto_datapoints.OneHotLabel.wrap_like(inpt, output)
class RandomMixup(_BaseMixupCutmix):
class RandomMixUp(_BaseMixUpCutMix):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type]
......@@ -60,7 +60,7 @@ class RandomMixup(_BaseMixupCutmix):
return inpt
class RandomCutmix(_BaseMixupCutmix):
class RandomCutMix(_BaseMixUpCutMix):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
lam = float(self._dist.sample(())) # type: ignore[arg-type]
......
......@@ -4,7 +4,7 @@ from . import functional, utils # usort: skip
from ._transform import Transform # usort: skip
from ._augment import Cutmix, Mixup, RandomErasing
from ._augment import CutMix, MixUp, RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
......
......@@ -140,7 +140,7 @@ class RandomErasing(_RandomApplyTransform):
return inpt
class _BaseMixupCutmix(Transform):
class _BaseMixUpCutMix(Transform):
def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
super().__init__()
self.alpha = float(alpha)
......@@ -203,10 +203,10 @@ class _BaseMixupCutmix(Transform):
return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
class Mixup(_BaseMixupCutmix):
class MixUp(_BaseMixUpCutMix):
"""[BETA] Apply MixUp to the provided batch of images and labels.
.. v2betastatus:: Mixup transform
.. v2betastatus:: MixUp transform
Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.
......@@ -227,7 +227,7 @@ class Mixup(_BaseMixupCutmix):
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter a the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``Mixup()(imgs_batch, labels_batch)``.
common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
It can also be a callable that takes the same input as the transform, and returns the labels.
"""
......@@ -252,10 +252,10 @@ class Mixup(_BaseMixupCutmix):
return inpt
class Cutmix(_BaseMixupCutmix):
class CutMix(_BaseMixUpCutMix):
"""[BETA] Apply CutMix to the provided batch of images and labels.
.. v2betastatus:: Cutmix transform
.. v2betastatus:: CutMix transform
Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
<https://arxiv.org/abs/1905.04899>`_.
......@@ -277,7 +277,7 @@ class Cutmix(_BaseMixupCutmix):
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter a the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``Cutmix()(imgs_batch, labels_batch)``.
common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
It can also be a callable that takes the same input as the transform, and returns the labels.
"""
......
......@@ -89,7 +89,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
This heuristic covers three cases:
1. The input is tuple or list whose second item is a labels tensor. This happens for already batched
classification inputs for Mixup and Cutmix (typically after the Dataloder).
classification inputs for MixUp and CutMix (typically after the Dataloder).
2. The input is a tuple or list whose second item is a dictionary that contains the labels tensor
under a label-like (see below) key. This happens for the inputs of detection models.
3. The input is a dictionary that is structured as the one from 2.
......@@ -103,7 +103,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
if isinstance(inputs, (tuple, list)):
inputs = inputs[1]
# Mixup, Cutmix
# MixUp, CutMix
if isinstance(inputs, torch.Tensor):
return inputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment