Unverified commit c585a515 authored by Mahdi Lamb, committed by GitHub

Enable one-hot-encoded labels in MixUp and CutMix (#8427)


Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
parent 778ce48b
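In short: CutMix and MixUp now also accept labels that are already one-hot-encoded (2D tensors of shape (batch_size, num_classes)), in which case num_classes may be left as None. A minimal sketch of the new behaviour, with illustrative shapes and class count (not part of this diff):

import torch
from torchvision.transforms import v2 as transforms

imgs = torch.rand(4, 3, 224, 224)             # a batch of 4 images
idx_labels = torch.randint(0, 10, size=(4,))  # 1D index labels
oh_labels = torch.nn.functional.one_hot(idx_labels, num_classes=10)  # 2D one-hot labels

# 1D index labels: num_classes is still required, as before
imgs_out, targets = transforms.MixUp(num_classes=10)(imgs, idx_labels)

# 2D one-hot labels: new in this PR, num_classes may be left as None
imgs_out, targets = transforms.MixUp()(imgs, oh_labels)
print(targets.shape)  # torch.Size([4, 10]) in both cases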
@@ -2169,26 +2169,30 @@ class TestAdjustBrightness:
 class TestCutMixMixUp:
     class DummyDataset:
-        def __init__(self, size, num_classes):
+        def __init__(self, size, num_classes, one_hot_labels):
             self.size = size
             self.num_classes = num_classes
+            self.one_hot_labels = one_hot_labels
             assert size < num_classes

         def __getitem__(self, idx):
             img = torch.rand(3, 100, 100)
             label = idx  # This ensures all labels in a batch are unique and makes testing easier
+            if self.one_hot_labels:
+                label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_classes)
             return img, label

         def __len__(self):
             return self.size

     @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
-    def test_supported_input_structure(self, T):
+    @pytest.mark.parametrize("one_hot_labels", (True, False))
+    def test_supported_input_structure(self, T, one_hot_labels):
         batch_size = 32
         num_classes = 100

-        dataset = self.DummyDataset(size=batch_size, num_classes=num_classes)
+        dataset = self.DummyDataset(size=batch_size, num_classes=num_classes, one_hot_labels=one_hot_labels)

         cutmix_mixup = T(num_classes=num_classes)

@@ -2198,7 +2202,7 @@ class TestCutMixMixUp:
         img, target = next(iter(dl))
         input_img_size = img.shape[-3:]
         assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
-        assert target.shape == (batch_size,)
+        assert target.shape == ((batch_size, num_classes) if one_hot_labels else (batch_size,))

         def check_output(img, target):
             assert img.shape == (batch_size, *input_img_size)

@@ -2209,7 +2213,7 @@ class TestCutMixMixUp:
         # After Dataloader, as unpacked input
         img, target = next(iter(dl))
-        assert target.shape == (batch_size,)
+        assert target.shape == ((batch_size, num_classes) if one_hot_labels else (batch_size,))
         img, target = cutmix_mixup(img, target)
         check_output(img, target)

@@ -2264,7 +2268,7 @@ class TestCutMixMixUp:
         with pytest.raises(ValueError, match="Could not infer where the labels are"):
             cutmix_mixup({"img": imgs, "Nothing_else": 3})

-        with pytest.raises(ValueError, match="labels tensor should be of shape"):
+        with pytest.raises(ValueError, match="labels should be index based"):
             # Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
             # It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
             cutmix_mixup(imgs)

@@ -2272,22 +2276,21 @@ class TestCutMixMixUp:
         with pytest.raises(ValueError, match="When using the default labels_getter"):
             cutmix_mixup(imgs, "not_a_tensor")

-        with pytest.raises(ValueError, match="labels tensor should be of shape"):
-            cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3)))
-
         with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
             cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,)))

         with pytest.raises(ValueError, match="does not match the batch size of the labels"):
             cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,)))

-        with pytest.raises(ValueError, match="labels tensor should be of shape"):
-            # The purpose of this check is more about documenting the current
-            # behaviour of what happens on a Compose(), rather than actually
-            # asserting the expected behaviour. We may support Compose() in the
-            # future, e.g. for 2 consecutive CutMix?
-            labels = torch.randint(0, num_classes, size=(batch_size,))
-            transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels)
+        with pytest.raises(ValueError, match="When passing 2D labels"):
+            wrong_num_classes = num_classes + 1
+            T(alpha=0.5, num_classes=num_classes)(imgs, torch.randint(0, 2, size=(batch_size, wrong_num_classes)))
+
+        with pytest.raises(ValueError, match="but got a tensor of shape"):
+            cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3, 4)))
+
+        with pytest.raises(ValueError, match="num_classes must be passed"):
+            T(alpha=0.5)(imgs, torch.randint(0, num_classes, size=(batch_size,)))

     @pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
 import math
 import numbers
 import warnings
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import PIL.Image
 import torch
@@ -142,7 +142,7 @@ class RandomErasing(_RandomApplyTransform):
 class _BaseMixUpCutMix(Transform):
-    def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
+    def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None:
         super().__init__()
         self.alpha = float(alpha)
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
@@ -162,10 +162,21 @@ class _BaseMixUpCutMix(Transform):
         labels = self._labels_getter(inputs)
         if not isinstance(labels, torch.Tensor):
             raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
-        elif labels.ndim != 1:
+        if labels.ndim not in (1, 2):
             raise ValueError(
-                f"labels tensor should be of shape (batch_size,) " f"but got shape {labels.shape} instead."
+                f"labels should be index based with shape (batch_size,) "
+                f"or probability based with shape (batch_size, num_classes), "
+                f"but got a tensor of shape {labels.shape} instead."
             )
+        if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes:
+            raise ValueError(
+                f"When passing 2D labels, "
+                f"the number of elements in last dimension must match num_classes: "
+                f"{labels.shape[-1]} != {self.num_classes}. "
+                f"You can leave num_classes as None."
+            )
+        if labels.ndim == 1 and self.num_classes is None:
+            raise ValueError("num_classes must be passed if the labels are index-based (1D)")

         params = {
             "labels": labels,
@@ -198,7 +209,8 @@ class _BaseMixUpCutMix(Transform):
         )

     def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
-        label = one_hot(label, num_classes=self.num_classes)
+        if label.ndim == 1:
+            label = one_hot(label, num_classes=self.num_classes)  # type: ignore[arg-type]
         if not label.dtype.is_floating_point:
             label = label.float()
         return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
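For reference, the returned tensor is the convex combination lam * label + (1 - lam) * label.roll(1, 0): each target row is blended with the target of the neighbouring sample in the batch, and 2D labels now skip the one_hot step. A tiny worked example (values are illustrative):

import torch

label = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # one-hot targets for a batch of 2
lam = 0.7
mixed = label.roll(1, 0).mul(1.0 - lam).add(label.mul(lam))
print(mixed)  # rows mix with their rolled neighbour: [[0.7, 0.3], [0.3, 0.7]]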
@@ -223,7 +235,8 @@ class MixUp(_BaseMixUpCutMix):
     Args:
         alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
-        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
+        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
+            Can be None only if the labels are already one-hot-encoded.
         labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
             By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
             common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
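Because the default labels_getter picks the second positional argument, one common way to apply the transform per batch is inside a DataLoader collate function; a sketch under that assumption (dataset and NUM_CLASSES are placeholders, not part of this diff):

from torch.utils.data import DataLoader, default_collate
from torchvision.transforms import v2 as transforms

NUM_CLASSES = 100  # placeholder
mixup = transforms.MixUp(num_classes=NUM_CLASSES)

def collate_fn(batch):
    # Collate the (img, label) pairs into batched tensors, then mix across the batch.
    return mixup(*default_collate(batch))

# dl = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)  # dataset is a placeholder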
@@ -271,7 +284,8 @@ class CutMix(_BaseMixUpCutMix):
     Args:
         alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
-        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
+        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
+            Can be None only if the labels are already one-hot-encoded.
         labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
             By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
             common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.