import math
import numbers
import warnings
from typing import Any, Dict, List, Tuple, Union

import PIL.Image
import torch
from torch.nn.functional import one_hot
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchvision import datapoints, transforms as _transforms
from torchvision.transforms.v2 import functional as F

from ._transform import _RandomApplyTransform, Transform
from ._utils import _parse_labels_getter
from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size


class RandomErasing(_RandomApplyTransform):
    """[BETA] Randomly select a rectangle region in the input image or video and erase its pixels.

    .. v2betastatus:: RandomErasing transform

    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p (float, optional): probability that the random erasing operation will be performed.
        scale (tuple of float, optional): range of proportion of erased area against input image.
        ratio (tuple of float, optional): range of aspect ratio of erased area.
        value (number or tuple of numbers): erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase R, G, B channels respectively.
            If the str ``'random'``, each pixel is erased with random values.
        inplace (bool, optional): boolean to make this transform inplace. Default set to False.

    Returns:
        Erased input.

    Example:
        >>> from torchvision.transforms import v2 as transforms
        >>>
        >>> transform = transforms.Compose([
        >>>     transforms.RandomHorizontalFlip(),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>     transforms.RandomErasing(),
        >>> ])
    """

    _v1_transform_cls = _transforms.RandomErasing

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        return dict(
            super()._extract_params_for_v1_transform(),
            value="random" if self.value is None else self.value,
        )

    _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)

    def __init__(
        self,
        p: float = 0.5,
        scale: Tuple[float, float] = (0.02, 0.33),
        ratio: Tuple[float, float] = (0.3, 3.3),
        value: float = 0.0,
        inplace: bool = False,
    ):
        super().__init__(p=p)
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")

        self.scale = scale
        self.ratio = ratio
        if isinstance(value, (int, float)):
            self.value = [float(value)]
        elif isinstance(value, str):
            self.value = None
        elif isinstance(value, (list, tuple)):
            self.value = [float(v) for v in value]
        else:
            self.value = value
        self.inplace = inplace

        self._log_ratio = torch.log(torch.tensor(self.ratio))

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        img_c, img_h, img_w = query_chw(flat_inputs)

        if self.value is not None and not (len(self.value) in (1, img_c)):
            raise ValueError(
                f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
            )

        area = img_h * img_w
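        # Rejection sampling: draw up to 10 candidate (scale, aspect ratio) pairs and keep the first
        # erasing box that fits inside the image. If none fits, fall back to v=None, in which case
        # _transform returns the input unchanged.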
        log_ratio = self._log_ratio
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(
                    log_ratio[0],  # type: ignore[arg-type]
                    log_ratio[1],  # type: ignore[arg-type]
                )
            ).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            if self.value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(self.value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            break
        else:
            i, j, h, w, v = 0, 0, img_h, img_w, None

        return dict(i=i, j=j, h=h, w=w, v=v)

    def _transform(
        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
        if params["v"] is not None:
            inpt = F.erase(inpt, **params, inplace=self.inplace)

        return inpt


class _BaseMixupCutmix(Transform):
    def __init__(self, *, alpha: float = 1, num_classes: int, labels_getter="default") -> None:
        super().__init__()
        self.alpha = alpha
        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

        self.num_classes = num_classes

        self._labels_getter = _parse_labels_getter(labels_getter)

    def forward(self, *inputs):
        inputs = inputs if len(inputs) > 1 else inputs[0]
        flat_inputs, spec = tree_flatten(inputs)
        needs_transform_list = self._needs_transform_list(flat_inputs)

        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask):
            raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")

        labels = self._labels_getter(inputs)
        if not isinstance(labels, torch.Tensor):
            raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
        elif labels.ndim != 1:
            raise ValueError(
                f"labels tensor should be of shape (batch_size,) but got shape {labels.shape} instead."
            )

        params = {
            "labels": labels,
            "batch_size": labels.shape[0],
            **self._get_params(
                [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
            ),
        }

        # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
        # after an image or video. However, we need to handle them in _transform, so we make sure to set them to True
        needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)

    def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
        expected_num_dims = 5 if isinstance(inpt, datapoints.Video) else 4
        if inpt.ndim != expected_num_dims:
            raise ValueError(
                f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
            )
        if inpt.shape[0] != batch_size:
            raise ValueError(
                f"The batch size of the image or video does not match the batch size of the labels: "
                f"{inpt.shape[0]} != {batch_size}."
            )

    def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
        # One-hot encode, then mix each label with the label of the next sample in the (rolled) batch:
        # lam * label + (1 - lam) * label_rolled.
        label = one_hot(label, num_classes=self.num_classes)
        if not label.dtype.is_floating_point:
            label = label.float()
        return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))


class Mixup(_BaseMixupCutmix):
    """[BETA] Apply Mixup to the provided batch of images and labels.

    .. v2betastatus:: Mixup transform

    Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.

    See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``Mixup()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
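
    Example (an illustrative sketch of the call pattern described above; ``imgs_batch`` is assumed to be a
    float image batch of shape ``(batch_size, C, H, W)`` and ``labels_batch`` an integer class-index tensor
    of shape ``(batch_size,)``):

        >>> mixup = Mixup(alpha=1.0, num_classes=10)
        >>> imgs_batch, labels_batch = mixup(imgs_batch, labels_batch)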
    """

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        lam = params["lam"]

        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=lam)
        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
            self._check_image_or_video(inpt, batch_size=params["batch_size"])

            # Blend each image with the next one in the batch using the sampled mixing coefficient.
            output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
                output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]

            return output
        else:
            return inpt


class Cutmix(_BaseMixupCutmix):
    """[BETA] Apply Cutmix to the provided batch of images and labels.

    .. v2betastatus:: Cutmix transform

    Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
    <https://arxiv.org/abs/1905.04899>`_.

    See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for cutmix. Default is 1.
        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``Cutmix()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
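
    Example (an illustrative sketch of the call pattern described above; ``imgs_batch`` is assumed to be an
    image batch of shape ``(batch_size, C, H, W)`` and ``labels_batch`` an integer class-index tensor of
    shape ``(batch_size,)``):

        >>> cutmix = Cutmix(alpha=1.0, num_classes=10)
        >>> imgs_batch, labels_batch = cutmix(imgs_batch, labels_batch)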
""" def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: lam = float(self._dist.sample(())) # type: ignore[arg-type] H, W = query_spatial_size(flat_inputs) r_x = torch.randint(W, size=(1,)) r_y = torch.randint(H, size=(1,)) r = 0.5 * math.sqrt(1.0 - lam) r_w_half = int(r * W) r_h_half = int(r * H) x1 = int(torch.clamp(r_x - r_w_half, min=0)) y1 = int(torch.clamp(r_y - r_h_half, min=0)) x2 = int(torch.clamp(r_x + r_w_half, max=W)) y2 = int(torch.clamp(r_y + r_h_half, max=H)) box = (x1, y1, x2, y2) lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) return dict(box=box, lam_adjusted=lam_adjusted) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if inpt is params["labels"]: return self._mixup_label(inpt, lam=params["lam_adjusted"]) elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt): self._check_image_or_video(inpt, batch_size=params["batch_size"]) x1, y1, x2, y2 = params["box"] rolled = inpt.roll(1, 0) output = inpt.clone() output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2] if isinstance(inpt, (datapoints.Image, datapoints.Video)): output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] return output else: return inpt