Unverified Commit 3e4e353d authored by Nicolas Hug, committed by GitHub

Cutmix -> CutMix (#7784)

parent edde8255
@@ -274,8 +274,8 @@ are combining pairs of images together. These can be used after the dataloader
     :toctree: generated/
     :template: class.rst

-    v2.Cutmix
-    v2.Mixup
+    v2.CutMix
+    v2.MixUp

 .. _functional_transforms:
......
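For context on the renamed classes, below is a minimal usage sketch of the v2 transforms listed above, applied right after the DataLoader as the docs describe. The dummy dataset, batch size, and NUM_CLASSES are made-up placeholders, and the snippet assumes a torchvision build that exposes the v2 API.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import v2

NUM_CLASSES = 10  # placeholder value for illustration

# Dummy classification data: images [N, C, H, W] and integer labels [N].
dataset = TensorDataset(
    torch.rand(64, 3, 224, 224), torch.randint(0, NUM_CLASSES, (64,))
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Randomly pick CutMix or MixUp for each batch.
cutmix_or_mixup = v2.RandomChoice(
    [v2.CutMix(num_classes=NUM_CLASSES), v2.MixUp(num_classes=NUM_CLASSES)]
)

for images, labels in loader:
    # Both transforms operate on whole batches and turn the integer labels
    # into soft targets of shape [batch_size, NUM_CLASSES].
    images, labels = cutmix_or_mixup(images, labels)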
@@ -13,15 +13,15 @@ def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
     mixup_cutmix = []
     if mixup_alpha > 0:
         mixup_cutmix.append(
-            transforms_module.Mixup(alpha=mixup_alpha, num_categories=num_categories)
+            transforms_module.MixUp(alpha=mixup_alpha, num_categories=num_categories)
             if use_v2
-            else RandomMixup(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
+            else RandomMixUp(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
         )
     if cutmix_alpha > 0:
         mixup_cutmix.append(
-            transforms_module.Cutmix(alpha=mixup_alpha, num_categories=num_categories)
+            transforms_module.CutMix(alpha=mixup_alpha, num_categories=num_categories)
             if use_v2
-            else RandomCutmix(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
+            else RandomCutMix(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
         )
     if not mixup_cutmix:
         return None
@@ -29,8 +29,8 @@ def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
     return transforms_module.RandomChoice(mixup_cutmix)

-class RandomMixup(torch.nn.Module):
-    """Randomly apply Mixup to the provided batch and targets.
+class RandomMixUp(torch.nn.Module):
+    """Randomly apply MixUp to the provided batch and targets.

     The class implements the data augmentations as described in the paper
     `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
@@ -112,8 +112,8 @@ class RandomMixup(torch.nn.Module):
         return s

-class RandomCutmix(torch.nn.Module):
-    """Randomly apply Cutmix to the provided batch and targets.
+class RandomCutMix(torch.nn.Module):
+    """Randomly apply CutMix to the provided batch and targets.

     The class implements the data augmentations as described in the paper
     `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
     <https://arxiv.org/abs/1905.04899>`_.
......
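The helper above returns a RandomChoice over the selected transforms, or None when both alphas are zero. One way it can be wired into training is through a collate_fn, so mixing happens on whole batches; this is only a hedged sketch, with made-up hyperparameter values and an assumed import path, not the reference training script's exact code.

from torch.utils.data import default_collate

# from transforms import get_mixup_cutmix  # assuming the module above is importable as `transforms`

# Illustrative values only.
mixup_cutmix = get_mixup_cutmix(
    mixup_alpha=0.2, cutmix_alpha=1.0, num_categories=1000, use_v2=True
)

def collate_fn(batch):
    # Collate samples into (images, labels) first, then mix the whole batch.
    collated = default_collate(batch)
    return mixup_cutmix(*collated) if mixup_cutmix is not None else collated

# loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)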
@@ -60,8 +60,8 @@ def parametrize(transforms_with_inputs):
             ],
         )
         for transform in [
-            transforms.RandomMixup(alpha=1.0),
-            transforms.RandomCutmix(alpha=1.0),
+            transforms.RandomMixUp(alpha=1.0),
+            transforms.RandomCutMix(alpha=1.0),
         ]
     ]
 )
......
@@ -1914,7 +1914,7 @@ class TestCutMixMixUp:
         def __len__(self):
             return self.size

-    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
+    @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
     def test_supported_input_structure(self, T):
         batch_size = 32
@@ -1964,7 +1964,7 @@ class TestCutMixMixUp:
         check_output(img, target)

     @needs_cuda
-    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
+    @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
     def test_cpu_vs_gpu(self, T):
         num_classes = 10
         batch_size = 3
@@ -1976,7 +1976,7 @@ class TestCutMixMixUp:
         _check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None)

-    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
+    @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
     def test_error(self, T):
         num_classes = 10
......
 from ._presets import StereoMatching  # usort: skip
-from ._augment import RandomCutmix, RandomMixup, SimpleCopyPaste
+from ._augment import RandomCutMix, RandomMixUp, SimpleCopyPaste
 from ._geometry import FixedSizeCrop
 from ._misc import PermuteDimensions, TransposeDimensions
 from ._type_conversion import LabelToOneHot
@@ -14,7 +14,7 @@ from torchvision.transforms.v2.functional._geometry import _check_interpolation
 from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size

-class _BaseMixupCutmix(_RandomApplyTransform):
+class _BaseMixUpCutMix(_RandomApplyTransform):
     def __init__(self, alpha: float, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.alpha = alpha
@@ -38,7 +38,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
         return proto_datapoints.OneHotLabel.wrap_like(inpt, output)

-class RandomMixup(_BaseMixupCutmix):
+class RandomMixUp(_BaseMixUpCutMix):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]
@@ -60,7 +60,7 @@ class RandomMixup(_BaseMixupCutmix):
         return inpt

-class RandomCutmix(_BaseMixupCutmix):
+class RandomCutMix(_BaseMixUpCutMix):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))  # type: ignore[arg-type]
......
@@ -4,7 +4,7 @@ from . import functional, utils  # usort: skip
 from ._transform import Transform  # usort: skip

-from ._augment import Cutmix, Mixup, RandomErasing
+from ._augment import CutMix, MixUp, RandomErasing
 from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
 from ._color import (
     ColorJitter,
......
@@ -140,7 +140,7 @@ class RandomErasing(_RandomApplyTransform):
         return inpt

-class _BaseMixupCutmix(Transform):
+class _BaseMixUpCutMix(Transform):
     def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
         super().__init__()
         self.alpha = float(alpha)
@@ -203,10 +203,10 @@ class _BaseMixupCutmix(Transform):
         return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))

-class Mixup(_BaseMixupCutmix):
+class MixUp(_BaseMixUpCutMix):
     """[BETA] Apply MixUp to the provided batch of images and labels.

-    .. v2betastatus:: Mixup transform
+    .. v2betastatus:: MixUp transform

     Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.
@@ -227,7 +227,7 @@ class Mixup(_BaseMixupCutmix):
         num_classes (int): number of classes in the batch. Used for one-hot-encoding.
         labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
             By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
-            common scenario where this transform is called as ``Mixup()(imgs_batch, labels_batch)``.
+            common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
             It can also be a callable that takes the same input as the transform, and returns the labels.
     """
@@ -252,10 +252,10 @@ class Mixup(_BaseMixupCutmix):
         return inpt

-class Cutmix(_BaseMixupCutmix):
+class CutMix(_BaseMixUpCutMix):
     """[BETA] Apply CutMix to the provided batch of images and labels.

-    .. v2betastatus:: Cutmix transform
+    .. v2betastatus:: CutMix transform

     Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
     <https://arxiv.org/abs/1905.04899>`_.
@@ -277,7 +277,7 @@ class Cutmix(_BaseMixupCutmix):
         num_classes (int): number of classes in the batch. Used for one-hot-encoding.
         labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
             By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
-            common scenario where this transform is called as ``Cutmix()(imgs_batch, labels_batch)``.
+            common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
             It can also be a callable that takes the same input as the transform, and returns the labels.
     """
......
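Since the docstrings touched above describe the labels_getter contract, here is a small, hedged sketch of both calling conventions: the default picks the second positional argument as the labels, and a callable can be passed when the batch has a different structure. The dict keys below are made up for illustration.

import torch
from torchvision.transforms import v2

imgs = torch.rand(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))

# Default labels_getter: the second argument is treated as the labels tensor.
cutmix = v2.CutMix(num_classes=10)
mixed_imgs, soft_labels = cutmix(imgs, labels)  # soft_labels has shape [4, 10]

# Callable labels_getter: it receives the same input as the transform and
# must return the labels tensor. The "image"/"target" keys are illustrative.
batch = {"image": imgs, "target": labels}
mixup = v2.MixUp(num_classes=10, labels_getter=lambda b: b["target"])
mixed_batch = mixup(batch)  # same structure, with mixed image and soft target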
@@ -89,7 +89,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
     This heuristic covers three cases:

     1. The input is a tuple or list whose second item is a labels tensor. This happens for already batched
-       classification inputs for Mixup and Cutmix (typically after the DataLoader).
+       classification inputs for MixUp and CutMix (typically after the DataLoader).
     2. The input is a tuple or list whose second item is a dictionary that contains the labels tensor
        under a label-like (see below) key. This happens for the inputs of detection models.
     3. The input is a dictionary that is structured as the one from 2.
@@ -103,7 +103,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
     if isinstance(inputs, (tuple, list)):
         inputs = inputs[1]

-        # Mixup, Cutmix
+        # MixUp, CutMix
         if isinstance(inputs, torch.Tensor):
             return inputs
......
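To make the three cases of the default-label heuristic above concrete, here is a small illustration of the input structures it is meant to cover; tensors, shapes, and dict keys are made up for the example.

import torch

labels = torch.randint(0, 10, (4,))

# Case 1: already-batched classification input, labels as the second item
# (what MixUp and CutMix typically receive right after the DataLoader).
case_1 = (torch.rand(4, 3, 32, 32), labels)

# Case 2: (image, target-dict) pair as used for detection models, with the
# labels stored under a label-like key.
case_2 = (torch.rand(3, 32, 32), {"boxes": torch.rand(4, 4), "labels": labels})

# Case 3: a bare dict structured like the target dict in case 2.
case_3 = {"boxes": torch.rand(4, 4), "labels": labels}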