Unverified Commit 35913710 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

promote Mixup and Cutmix from prototype to transforms v2 (#7731)


Co-authored-by: default avatarNicolas Hug <nicolashug@meta.com>
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 8071c177
...@@ -261,6 +261,22 @@ The new transform can be used standalone or mixed-and-matched with existing tran ...@@ -261,6 +261,22 @@ The new transform can be used standalone or mixed-and-matched with existing tran
AugMix AugMix
v2.AugMix v2.AugMix
Cutmix - Mixup
--------------
Cutmix and Mixup are special transforms that
are meant to be used on batches rather than on individual images, because they
are combining pairs of images together. These can be used after the dataloader,
or as part of a collation function. See
:ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
.. autosummary::
:toctree: generated/
:template: class.rst
v2.Cutmix
v2.Mixup
.. _functional_transforms: .. _functional_transforms:
Functional Transforms Functional Transforms
......
"""
===========================
How to use Cutmix and Mixup
===========================
This example illustrates how to use Cutmix and Mixup on batches of images and labels.
"""
...@@ -8,12 +8,12 @@ import torch ...@@ -8,12 +8,12 @@ import torch
import torch.utils.data import torch.utils.data
import torchvision import torchvision
import torchvision.transforms import torchvision.transforms
import transforms
import utils import utils
from sampler import RASampler from sampler import RASampler
from torch import nn from torch import nn
from torch.utils.data.dataloader import default_collate from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from transforms import get_mixup_cutmix
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
...@@ -218,18 +218,17 @@ def main(args): ...@@ -218,18 +218,17 @@ def main(args):
val_dir = os.path.join(args.data_path, "val") val_dir = os.path.join(args.data_path, "val")
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
collate_fn = None
num_classes = len(dataset.classes) num_classes = len(dataset.classes)
mixup_transforms = [] mixup_cutmix = get_mixup_cutmix(
if args.mixup_alpha > 0.0: mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2
mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha)) )
if args.cutmix_alpha > 0.0: if mixup_cutmix is not None:
mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
if mixup_transforms:
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
def collate_fn(batch): def collate_fn(batch):
return mixupcutmix(*default_collate(batch)) return mixup_cutmix(*default_collate(batch))
else:
collate_fn = default_collate
data_loader = torch.utils.data.DataLoader( data_loader = torch.utils.data.DataLoader(
dataset, dataset,
......
...@@ -2,10 +2,33 @@ import math ...@@ -2,10 +2,33 @@ import math
from typing import Tuple from typing import Tuple
import torch import torch
from presets import get_module
from torch import Tensor from torch import Tensor
from torchvision.transforms import functional as F from torchvision.transforms import functional as F
def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
    """Build a RandomChoice over Mixup/Cutmix batch transforms, or return None if both are disabled.

    Args:
        mixup_alpha (float): Beta-distribution parameter for Mixup; ``<= 0`` disables Mixup.
        cutmix_alpha (float): Beta-distribution parameter for Cutmix; ``<= 0`` disables Cutmix.
        num_categories (int): number of classes, used for one-hot-encoding the labels.
        use_v2 (bool): use the transforms.v2 implementations instead of the reference ones.

    Returns:
        A ``RandomChoice`` transform picking one of the enabled transforms per batch,
        or ``None`` when neither alpha is positive.
    """
    transforms_module = get_module(use_v2)

    mixup_cutmix = []
    if mixup_alpha > 0:
        mixup_cutmix.append(
            # v2 transforms take `num_classes`, not `num_categories` (fixed: the original
            # passed num_categories=..., which raises TypeError on the v2 path)
            transforms_module.Mixup(alpha=mixup_alpha, num_classes=num_categories)
            if use_v2
            else RandomMixup(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
        )
    if cutmix_alpha > 0:
        mixup_cutmix.append(
            # Fixed: the original passed alpha=mixup_alpha to Cutmix
            transforms_module.Cutmix(alpha=cutmix_alpha, num_classes=num_categories)
            if use_v2
            else RandomCutmix(num_classes=num_categories, p=1.0, alpha=cutmix_alpha)
        )
    if not mixup_cutmix:
        return None

    return transforms_module.RandomChoice(mixup_cutmix)
class RandomMixup(torch.nn.Module): class RandomMixup(torch.nn.Module):
"""Randomly apply Mixup to the provided batch and targets. """Randomly apply Mixup to the provided batch and targets.
The class implements the data augmentations as described in the paper The class implements the data augmentations as described in the paper
......
...@@ -1558,9 +1558,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): ...@@ -1558,9 +1558,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
@pytest.mark.parametrize("min_size", (1, 10)) @pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize( @pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
"labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None)
)
@pytest.mark.parametrize("sample_type", (tuple, dict)) @pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
...@@ -1648,22 +1646,6 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): ...@@ -1648,22 +1646,6 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
assert out_labels.tolist() == valid_indices assert out_labels.tolist() == valid_indices
@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes_default_heuristic(key, sample_type):
labels = torch.arange(10)
sample = {key: labels, "another_key": "whatever"}
if sample_type is tuple:
sample = (None, sample, "whatever_again")
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(sample) is labels
if key.lower() != "labels":
# If "labels" is in the dict (case-insensitive),
# it takes precedence over other keys which would otherwise be a match
d = {key: "something_else", "labels": labels}
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(d) is labels
def test_sanitize_bounding_boxes_errors(): def test_sanitize_bounding_boxes_errors():
good_bbox = datapoints.BoundingBox( good_bbox = datapoints.BoundingBox(
...@@ -1674,17 +1656,13 @@ def test_sanitize_bounding_boxes_errors(): ...@@ -1674,17 +1656,13 @@ def test_sanitize_bounding_boxes_errors():
with pytest.raises(ValueError, match="min_size must be >= 1"): with pytest.raises(ValueError, match="min_size must be >= 1"):
transforms.SanitizeBoundingBox(min_size=0) transforms.SanitizeBoundingBox(min_size=0)
with pytest.raises(ValueError, match="labels_getter should either be a str"): with pytest.raises(ValueError, match="labels_getter should either be 'default'"):
transforms.SanitizeBoundingBox(labels_getter=12) transforms.SanitizeBoundingBox(labels_getter=12)
with pytest.raises(ValueError, match="Could not infer where the labels are"): with pytest.raises(ValueError, match="Could not infer where the labels are"):
bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])} bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
transforms.SanitizeBoundingBox()(bad_labels_key) transforms.SanitizeBoundingBox()(bad_labels_key)
with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0]))
transforms.SanitizeBoundingBox()(not_a_dict)
with pytest.raises(ValueError, match="must be a tensor"): with pytest.raises(ValueError, match="must be a tensor"):
not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()} not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
transforms.SanitizeBoundingBox()(not_a_tensor) transforms.SanitizeBoundingBox()(not_a_tensor)
......
...@@ -17,6 +17,7 @@ from common_utils import ( ...@@ -17,6 +17,7 @@ from common_utils import (
assert_no_warnings, assert_no_warnings,
cache, cache,
cpu_and_cuda, cpu_and_cuda,
freeze_rng_state,
ignore_jit_no_profile_information_warning, ignore_jit_no_profile_information_warning,
make_bounding_box, make_bounding_box,
make_detection_mask, make_detection_mask,
...@@ -25,12 +26,14 @@ from common_utils import ( ...@@ -25,12 +26,14 @@ from common_utils import (
make_image_tensor, make_image_tensor,
make_segmentation_mask, make_segmentation_mask,
make_video, make_video,
needs_cuda,
set_rng_seed, set_rng_seed,
) )
from torch import nn from torch import nn
from torch.testing import assert_close from torch.testing import assert_close
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import datapoints from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms._functional_tensor import _max_value as get_max_value
...@@ -61,8 +64,10 @@ def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs): ...@@ -61,8 +64,10 @@ def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs):
input_cuda = input.as_subclass(torch.Tensor) input_cuda = input.as_subclass(torch.Tensor)
input_cpu = input_cuda.to("cpu") input_cpu = input_cuda.to("cpu")
actual = kernel(input_cuda, *args, **kwargs) with freeze_rng_state():
expected = kernel(input_cpu, *args, **kwargs) actual = kernel(input_cuda, *args, **kwargs)
with freeze_rng_state():
expected = kernel(input_cpu, *args, **kwargs)
assert_close(actual, expected, check_device=False, rtol=rtol, atol=atol) assert_close(actual, expected, check_device=False, rtol=rtol, atol=atol)
...@@ -1892,3 +1897,142 @@ class TestToDtype: ...@@ -1892,3 +1897,142 @@ class TestToDtype:
assert out["inpt"].dtype == inpt_dtype assert out["inpt"].dtype == inpt_dtype
assert out["bbox"].dtype == bbox_dtype assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype == mask_dtype assert out["mask"].dtype == mask_dtype
class TestCutMixMixUp:
    """Tests for the batch-level ``transforms.Cutmix`` and ``transforms.Mixup``.

    Both transforms mix pairs of samples within a batch, so they are exercised
    through a DataLoader and as collate functions rather than on single images.
    """

    class DummyDataset:
        # Minimal map-style dataset: random image, label == sample index.
        def __init__(self, size, num_classes):
            self.size = size
            self.num_classes = num_classes
            # size < num_classes guarantees every index is a valid class label
            assert size < num_classes

        def __getitem__(self, idx):
            img = torch.rand(3, 100, 100)
            label = idx  # This ensures all labels in a batch are unique and makes testing easier
            return img, label

        def __len__(self):
            return self.size

    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
    def test_supported_input_structure(self, T):
        # The transforms must accept: unpacked (img, target), a packed
        # [img, target] list, and use inside a collate function.
        batch_size = 32
        num_classes = 100

        dataset = self.DummyDataset(size=batch_size, num_classes=num_classes)

        cutmix_mixup = T(alpha=0.5, num_classes=num_classes)

        dl = DataLoader(dataset, batch_size=batch_size)

        # Input sanity checks
        img, target = next(iter(dl))
        input_img_size = img.shape[-3:]
        assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
        assert target.shape == (batch_size,)

        def check_output(img, target):
            # Output labels are soft one-hot: each row sums to 1 with exactly
            # two non-zero entries, since the batch labels are all unique.
            assert img.shape == (batch_size, *input_img_size)
            assert target.shape == (batch_size, num_classes)
            torch.testing.assert_close(target.sum(axis=-1), torch.ones(batch_size))
            num_non_zero_labels = (target != 0).sum(axis=-1)
            assert (num_non_zero_labels == 2).all()

        # After Dataloader, as unpacked input
        img, target = next(iter(dl))
        assert target.shape == (batch_size,)
        img, target = cutmix_mixup(img, target)
        check_output(img, target)

        # After Dataloader, as packed input
        packed_from_dl = next(iter(dl))
        assert isinstance(packed_from_dl, list)
        img, target = cutmix_mixup(packed_from_dl)
        check_output(img, target)

        # As collation function. We expect default_collate to be used by users.
        def collate_fn_1(batch):
            return cutmix_mixup(default_collate(batch))

        def collate_fn_2(batch):
            return cutmix_mixup(*default_collate(batch))

        for collate_fn in (collate_fn_1, collate_fn_2):
            dl = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
            img, target = next(iter(dl))
            check_output(img, target)

    @needs_cuda
    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
    def test_cpu_vs_gpu(self, T):
        # CPU and CUDA must produce the same result under a frozen RNG state
        # (the helper freezes the RNG around each call).
        num_classes = 10
        batch_size = 3
        H, W = 12, 12

        imgs = torch.rand(batch_size, 3, H, W)
        labels = torch.randint(0, num_classes, (batch_size,))
        cutmix_mixup = T(alpha=0.5, num_classes=num_classes)

        _check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None)

    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
    def test_error(self, T):
        # Unsupported input types and malformed labels must fail loudly.
        num_classes = 10
        batch_size = 9

        imgs = torch.rand(batch_size, 3, 12, 12)
        cutmix_mixup = T(alpha=0.5, num_classes=num_classes)

        for input_with_bad_type in (
            F.to_pil_image(imgs[0]),
            datapoints.Mask(torch.rand(12, 12)),
            datapoints.BoundingBox(torch.rand(2, 4), format="XYXY", spatial_size=12),
        ):
            with pytest.raises(ValueError, match="does not support PIL images, "):
                cutmix_mixup(input_with_bad_type)

        with pytest.raises(ValueError, match="Could not infer where the labels are"):
            cutmix_mixup({"img": imgs, "Nothing_else": 3})

        with pytest.raises(ValueError, match="labels tensor should be of shape"):
            # Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
            # It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
            cutmix_mixup(imgs)

        with pytest.raises(ValueError, match="When using the default labels_getter"):
            cutmix_mixup(imgs, "not_a_tensor")

        with pytest.raises(ValueError, match="labels tensor should be of shape"):
            cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3)))

        with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
            cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,)))

        with pytest.raises(ValueError, match="does not match the batch size of the labels"):
            cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,)))

        with pytest.raises(ValueError, match="labels tensor should be of shape"):
            # The purpose of this check is more about documenting the current
            # behaviour of what happens on a Compose(), rather than actually
            # asserting the expected behaviour. We may support Compose() in the
            # future, e.g. for 2 consecutive CutMix?
            labels = torch.randint(0, num_classes, size=(batch_size,))
            transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels)
@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
@pytest.mark.parametrize("sample_type", (tuple, list, dict))
def test_labels_getter_default_heuristic(key, sample_type):
    """The default heuristic finds the labels tensor under any label-like key."""
    labels = torch.arange(10)
    inner = {key: labels, "another_key": "whatever"}
    sample = inner if sample_type is dict else sample_type((None, inner, "whatever_again"))

    assert transforms._utils._find_labels_default_heuristic(sample) is labels

    if key.lower() != "labels":
        # An exact (case-insensitive) "labels" key always takes precedence over
        # keys that merely contain "label" somewhere in their name.
        ambiguous = {key: "something_else", "labels": labels}
        assert transforms._utils._find_labels_default_heuristic(ambiguous) is labels
...@@ -4,7 +4,7 @@ from . import functional, utils # usort: skip ...@@ -4,7 +4,7 @@ from . import functional, utils # usort: skip
from ._transform import Transform # usort: skip from ._transform import Transform # usort: skip
from ._augment import RandomErasing from ._augment import Cutmix, Mixup, RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import ( from ._color import (
ColorJitter, ColorJitter,
......
...@@ -5,11 +5,14 @@ from typing import Any, Dict, List, Tuple, Union ...@@ -5,11 +5,14 @@ from typing import Any, Dict, List, Tuple, Union
import PIL.Image import PIL.Image
import torch import torch
from torch.nn.functional import one_hot
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints, transforms as _transforms from torchvision import datapoints, transforms as _transforms
from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2 import functional as F
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform, Transform
from .utils import is_simple_tensor, query_chw from ._utils import _parse_labels_getter
from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size
class RandomErasing(_RandomApplyTransform): class RandomErasing(_RandomApplyTransform):
...@@ -135,3 +138,171 @@ class RandomErasing(_RandomApplyTransform): ...@@ -135,3 +138,171 @@ class RandomErasing(_RandomApplyTransform):
inpt = F.erase(inpt, **params, inplace=self.inplace) inpt = F.erase(inpt, **params, inplace=self.inplace)
return inpt return inpt
class _BaseMixupCutmix(Transform):
    # Common base for Mixup and Cutmix: validates the batched input, extracts the
    # labels, draws the mixing parameters, and dispatches _transform over the
    # flattened input structure.

    def __init__(self, *, alpha: float = 1, num_classes: int, labels_getter="default") -> None:
        super().__init__()
        self.alpha = alpha
        # Symmetric Beta(alpha, alpha) distribution from which the mixing coefficient is drawn.
        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
        self.num_classes = num_classes
        self._labels_getter = _parse_labels_getter(labels_getter)

    def forward(self, *inputs):
        """Apply the transform to a batched input (images/videos plus a 1-D labels tensor).

        Raises:
            ValueError: if the input contains PIL images, bounding boxes or masks,
                or if the labels are missing, not a tensor, or not 1-dimensional.
        """
        inputs = inputs if len(inputs) > 1 else inputs[0]
        flat_inputs, spec = tree_flatten(inputs)
        needs_transform_list = self._needs_transform_list(flat_inputs)

        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask):
            raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")

        labels = self._labels_getter(inputs)
        if not isinstance(labels, torch.Tensor):
            raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
        elif labels.ndim != 1:
            raise ValueError(
                f"labels tensor should be of shape (batch_size,) " f"but got shape {labels.shape} instead."
            )

        # The per-batch mixing parameters (e.g. lam, box) come from the subclass's
        # _get_params, computed only over the entries that will be transformed.
        params = {
            "labels": labels,
            "batch_size": labels.shape[0],
            **self._get_params(
                [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
            ),
        }

        # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
        # after an image or video. However, we need to handle them in _transform, so we make sure to set them to True
        needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)

    def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
        """Ensure ``inpt`` is a batched image (4D) or video (5D) whose batch size matches the labels'."""
        expected_num_dims = 5 if isinstance(inpt, datapoints.Video) else 4
        if inpt.ndim != expected_num_dims:
            raise ValueError(
                f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
            )
        if inpt.shape[0] != batch_size:
            raise ValueError(
                f"The batch size of the image or video does not match the batch size of the labels: "
                f"{inpt.shape[0]} != {batch_size}."
            )

    def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
        """One-hot-encode the labels and blend each row with the next row (batch rolled by 1), weighted by ``lam``."""
        label = one_hot(label, num_classes=self.num_classes)
        if not label.dtype.is_floating_point:
            label = label.float()
        # roll(1, 0) pairs sample i with sample i-1; in-place ops are safe because roll returns a new tensor.
        return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
class Mixup(_BaseMixupCutmix):
    """[BETA] Apply Mixup to the provided batch of images and labels.

    .. v2betastatus:: Mixup transform

    Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.

    See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.

    The input labels are expected to be a tensor of shape ``(batch_size,)``; they are
    converted to a soft one-hot tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``Mixup()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
    """

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # A single mixing coefficient is shared by all entries of the batch.
        return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        lam = params["lam"]

        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=lam)

        if not (isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt)):
            # Anything that is neither the labels nor an image/video batch passes through untouched.
            return inpt

        self._check_image_or_video(inpt, batch_size=params["batch_size"])

        # Blend each sample with its rolled neighbor: lam * x_i + (1 - lam) * x_{i-1}.
        # The in-place ops are safe because roll() returns a fresh tensor.
        mixed = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
            mixed = type(inpt).wrap_like(inpt, mixed)  # type: ignore[arg-type]

        return mixed
class Cutmix(_BaseMixupCutmix):
    """[BETA] Apply Cutmix to the provided batch of images and labels.

    .. v2betastatus:: Cutmix transform

    Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
    <https://arxiv.org/abs/1905.04899>`_.

    See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for cutmix. Default is 1.
        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``Cutmix()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
    """

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the patch box and the area-adjusted mixing coefficient."""
        lam = float(self._dist.sample(()))  # type: ignore[arg-type]

        H, W = query_spatial_size(flat_inputs)

        # The patch is centered at a uniformly random position, with half-extents
        # proportional to sqrt(1 - lam) so that its area ratio is roughly (1 - lam).
        r_x = torch.randint(W, size=(1,))
        r_y = torch.randint(H, size=(1,))

        r = 0.5 * math.sqrt(1.0 - lam)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))
        box = (x1, y1, x2, y2)

        # The box may have been clipped at the borders, so recompute lam from the actual patch area.
        lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        return dict(box=box, lam_adjusted=lam_adjusted)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=params["lam_adjusted"])
        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
            self._check_image_or_video(inpt, batch_size=params["batch_size"])

            # Paste the rolled neighbor's patch into a copy of the input.
            x1, y1, x2, y2 = params["box"]
            rolled = inpt.roll(1, 0)
            output = inpt.clone()
            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
                # Call the classmethod through the class, consistent with Mixup._transform
                # (the original called it through the instance: inpt.wrap_like(...)).
                output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]

            return output
        else:
            return inpt
import collections
import warnings import warnings
from contextlib import suppress from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Sequence, Type, Union
import PIL.Image import PIL.Image
...@@ -11,7 +9,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten ...@@ -11,7 +9,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints, transforms as _transforms from torchvision import datapoints, transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform from torchvision.transforms.v2 import functional as F, Transform
from ._utils import _setup_float_or_seq, _setup_size from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
from .utils import has_any, is_simple_tensor, query_bounding_box from .utils import has_any, is_simple_tensor, query_bounding_box
...@@ -318,12 +316,11 @@ class SanitizeBoundingBox(Transform): ...@@ -318,12 +316,11 @@ class SanitizeBoundingBox(Transform):
Args: Args:
min_size (float, optional) The size below which bounding boxes are removed. Default is 1. min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input. labels_getter (callable or str or None, optional): indicates how to identify the labels in the input.
It can be a str in which case the input is expected to be a dict, and ``labels_getter`` then specifies By default, this will try to find a "labels" key in the input (case-insensitive), if
the key whose value corresponds to the labels. It can also be a callable that takes the same input
as the transform, and returns the labels.
By default, this will try to find a "labels" key in the input, if
the input is a dict or it is a tuple whose second element is a dict. the input is a dict or it is a tuple whose second element is a dict.
This heuristic should work well with a lot of datasets, including the built-in torchvision datasets. This heuristic should work well with a lot of datasets, including the built-in torchvision datasets.
It can also be a callable that takes the same input
as the transform, and returns the labels.
""" """
def __init__( def __init__(
...@@ -338,66 +335,16 @@ class SanitizeBoundingBox(Transform): ...@@ -338,66 +335,16 @@ class SanitizeBoundingBox(Transform):
self.min_size = min_size self.min_size = min_size
self.labels_getter = labels_getter self.labels_getter = labels_getter
self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]] self._labels_getter = _parse_labels_getter(labels_getter)
if labels_getter == "default":
self._labels_getter = self._find_labels_default_heuristic
elif callable(labels_getter):
self._labels_getter = labels_getter
elif isinstance(labels_getter, str):
self._labels_getter = lambda inputs: SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)[
labels_getter # type: ignore[index]
]
elif labels_getter is None:
self._labels_getter = None
else:
raise ValueError(
"labels_getter should either be a str, callable, or 'default'. "
f"Got {labels_getter} of type {type(labels_getter)}."
)
@staticmethod
def _get_dict_or_second_tuple_entry(inputs: Any) -> Mapping[str, Any]:
# datasets outputs may be plain dicts like {"img": ..., "labels": ..., "bbox": ...}
# or tuples like (img, {"labels":..., "bbox": ...})
# This hacky helper accounts for both structures.
if isinstance(inputs, tuple):
inputs = inputs[1]
if not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"If labels_getter is a str or 'default', "
f"then the input to forward() must be a dict or a tuple whose second element is a dict."
f" Got {type(inputs)} instead."
)
return inputs
@staticmethod
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
# Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found
inputs = SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
if candidate_key is None:
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
if candidate_key is None:
raise ValueError(
"Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
"If there are no samples and it is by design, pass labels_getter=None."
)
return inputs[candidate_key]
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0] inputs = inputs if len(inputs) > 1 else inputs[0]
if self._labels_getter is None: labels = self._labels_getter(inputs)
labels = None if labels is not None and not isinstance(labels, torch.Tensor):
else: raise ValueError(
labels = self._labels_getter(inputs) f"The labels in the input to forward() must be a tensor or None, got {type(labels)} instead."
if labels is not None and not isinstance(labels, torch.Tensor): )
raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.")
flat_inputs, spec = tree_flatten(inputs) flat_inputs, spec = tree_flatten(inputs)
# TODO: this enforces one single BoundingBox entry. # TODO: this enforces one single BoundingBox entry.
......
import collections.abc
import functools import functools
import numbers import numbers
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union from contextlib import suppress
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, TypeVar, Union
import torch
from torchvision import datapoints from torchvision import datapoints
from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
...@@ -93,3 +96,60 @@ def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: ...@@ -93,3 +96,60 @@ def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None: def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
"""
This heuristic covers three cases:
1. The input is tuple or list whose second item is a labels tensor. This happens for already batched
classification inputs for Mixup and Cutmix (typically after the Dataloder).
2. The input is a tuple or list whose second item is a dictionary that contains the labels tensor
under a label-like (see below) key. This happens for the inputs of detection models.
3. The input is a dictionary that is structured as the one from 2.
What is "label-like" key? We first search for an case-insensitive match of 'labels' inside the keys of the
dictionary. This is the name our detection models expect. If we can't find that, we look for a case-insensitive
match of the term 'label' anywhere inside the key, i.e. 'FooLaBeLBar'. If we can't find that either, the dictionary
contains no "label-like" key.
"""
if isinstance(inputs, (tuple, list)):
inputs = inputs[1]
# Mixup, Cutmix
if isinstance(inputs, torch.Tensor):
return inputs
if not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"When using the default labels_getter, the input passed to forward must be a dictionary or a two-tuple "
f"whose second item is a dictionary or a tensor, but got {inputs} instead."
)
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
if candidate_key is None:
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
if candidate_key is None:
raise ValueError(
"Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
"If there are no labels in the sample by design, pass labels_getter=None."
)
return inputs[candidate_key]
def _parse_labels_getter(
labels_getter: Union[str, Callable[[Any], Optional[torch.Tensor]], None]
) -> Callable[[Any], Optional[torch.Tensor]]:
if labels_getter == "default":
return _find_labels_default_heuristic
elif callable(labels_getter):
return labels_getter
elif labels_getter is None:
return lambda _: None
else:
raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment