import math
import numbers
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import PIL.Image
import torch
from torch.nn.functional import one_hot
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F

from ._transform import _RandomApplyTransform, Transform
from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size


class RandomErasing(_RandomApplyTransform):
    """Randomly select a rectangle region in the input image or video and erase its pixels.

    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p (float, optional): probability that the random erasing operation will be performed.
        scale (tuple of float, optional): range of proportion of erased area against input image.
        ratio (tuple of float, optional): range of aspect ratio of erased area.
        value (number or tuple of numbers): erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If the str 'random' is given, each pixel is erased with random values.
        inplace (bool, optional): if True, apply the transform in place. Default: False.

    Returns:
        Erased input.

    Example:
        >>> from torchvision.transforms import v2 as transforms
        >>>
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.PILToTensor(),
        >>>   transforms.ConvertImageDtype(torch.float),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    _v1_transform_cls = _transforms.RandomErasing

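    # The v1 transform encodes "erase with random values" as value="random",
    # whereas this v2 implementation stores None; map back for the v1 fallback.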
    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        return dict(
            super()._extract_params_for_v1_transform(),
            value="random" if self.value is None else self.value,
        )

    def __init__(
        self,
        p: float = 0.5,
        scale: Tuple[float, float] = (0.02, 0.33),
        ratio: Tuple[float, float] = (0.3, 3.3),
        value: float = 0.0,
        inplace: bool = False,
    ):
        super().__init__(p=p)
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        self.scale = scale
        self.ratio = ratio
        if isinstance(value, (int, float)):
            self.value = [float(value)]
        elif isinstance(value, str):
            self.value = None
        elif isinstance(value, (list, tuple)):
            self.value = [float(v) for v in value]
        else:
            self.value = value
        self.inplace = inplace

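        # Precompute log(ratio) once so _get_params can sample the aspect
        # ratio uniformly in log-space.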
        self._log_ratio = torch.log(torch.tensor(self.ratio))

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
            warnings.warn(
                f"{type(self).__name__}() is currently passing through inputs of type "
                f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
            )
        return super()._call_kernel(functional, inpt, *args, **kwargs)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        img_c, img_h, img_w = query_chw(flat_inputs)

        if self.value is not None and len(self.value) not in (1, img_c):
            raise ValueError(
                f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
            )

        area = img_h * img_w

        log_ratio = self._log_ratio
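        # Rejection-sample an erase box: draw up to 10 (area, aspect ratio)
        # candidates and keep the first box that fits inside the image.
        # If none fits, fall back to v=None, which makes _transform a no-op.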
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(
                    log_ratio[0],  # type: ignore[arg-type]
                    log_ratio[1],  # type: ignore[arg-type]
                )
            ).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            if self.value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(self.value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            break
        else:
            i, j, h, w, v = 0, 0, img_h, img_w, None

        return dict(i=i, j=j, h=h, w=w, v=v)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if params["v"] is not None:
            inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace)

        return inpt


class _BaseMixUpCutMix(Transform):
    def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None:
        super().__init__()
        self.alpha = float(alpha)
        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

        self.num_classes = num_classes

        self._labels_getter = _parse_labels_getter(labels_getter)

    def forward(self, *inputs):
        inputs = inputs if len(inputs) > 1 else inputs[0]
        flat_inputs, spec = tree_flatten(inputs)
        needs_transform_list = self._needs_transform_list(flat_inputs)

        if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
            raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes, or masks.")

        labels = self._labels_getter(inputs)
        if not isinstance(labels, torch.Tensor):
            raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
        if labels.ndim not in (1, 2):
            raise ValueError(
                f"labels should be index based with shape (batch_size,) "
                f"or probability based with shape (batch_size, num_classes), "
                f"but got a tensor of shape {labels.shape} instead."
            )
        if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes:
            raise ValueError(
                f"When passing 2D labels, "
                f"the number of elements in last dimension must match num_classes: "
                f"{labels.shape[-1]} != {self.num_classes}. "
                f"You can leave num_classes as None."
            )
        if labels.ndim == 1 and self.num_classes is None:
            raise ValueError("num_classes must be passed if the labels are index-based (1D)")

        params = {
            "labels": labels,
            "batch_size": labels.shape[0],
            **self._get_params(
                [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
            ),
        }

        # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
        # after an image or video. However, we need to handle them in _transform, so we make sure to set them to True
        needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)

    def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
        expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
        if inpt.ndim != expected_num_dims:
            raise ValueError(
                f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
            )
        if inpt.shape[0] != batch_size:
            raise ValueError(
                f"The batch size of the image or video does not match the batch size of the labels: "
                f"{inpt.shape[0]} != {batch_size}."
            )

    def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
        if label.ndim == 1:
            label = one_hot(label, num_classes=self.num_classes)  # type: ignore[arg-type]
        if not label.dtype.is_floating_point:
            label = label.float()
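        # Pair each label with its predecessor in the batch via roll(1, 0)
        # and take the convex combination lam * label + (1 - lam) * rolled.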
        return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))


class MixUp(_BaseMixUpCutMix):
    """Apply MixUp to the provided batch of images and labels.

    Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images. See
        :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
        examples.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention).

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
            Can be None only if the labels are already one-hot-encoded.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
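
    Example:
        >>> # A minimal sketch; the batch shape and num_classes are illustrative.
        >>> import torch
        >>> from torchvision.transforms import v2 as transforms
        >>> images = torch.rand(4, 3, 16, 16)   # batch of 4 RGB images
        >>> labels = torch.randint(0, 10, (4,)) # index-based labels, shape (batch_size,)
        >>> images, labels = transforms.MixUp(num_classes=10)(images, labels)
        >>> labels.shape  # labels are now probability-based
        torch.Size([4, 10])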
    """

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        lam = params["lam"]

        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=lam)
        elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
            self._check_image_or_video(inpt, batch_size=params["batch_size"])

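            # Same pairing as the labels: mix each image with its predecessor.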
            output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

            if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
                output = tv_tensors.wrap(output, like=inpt)

            return output
        else:
            return inpt


class CutMix(_BaseMixUpCutMix):
    """Apply CutMix to the provided batch of images and labels.

    Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
    <https://arxiv.org/abs/1905.04899>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images. See
        :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
        examples.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention).

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
            Can be None only if the labels are already one-hot-encoded.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
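
    Example:
        >>> # A minimal sketch; the batch shape and num_classes are illustrative.
        >>> import torch
        >>> from torchvision.transforms import v2 as transforms
        >>> images = torch.rand(4, 3, 16, 16)   # batch of 4 RGB images
        >>> labels = torch.randint(0, 10, (4,)) # index-based labels, shape (batch_size,)
        >>> images, labels = transforms.CutMix(num_classes=10)(images, labels)
        >>> images.shape, labels.shape
        (torch.Size([4, 3, 16, 16]), torch.Size([4, 10]))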
    """

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        lam = float(self._dist.sample(()))  # type: ignore[arg-type]

        H, W = query_size(flat_inputs)

        r_x = torch.randint(W, size=(1,))
        r_y = torch.randint(H, size=(1,))

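        # Center the cut box at (r_x, r_y) with half-extents chosen so that the
        # unclamped box covers a (1 - lam) fraction of the area; lam is then
        # re-adjusted below to account for clamping at the image borders.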
        r = 0.5 * math.sqrt(1.0 - lam)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))
        box = (x1, y1, x2, y2)

        lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        return dict(box=box, lam_adjusted=lam_adjusted)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=params["lam_adjusted"])
        elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
            self._check_image_or_video(inpt, batch_size=params["batch_size"])

            x1, y1, x2, y2 = params["box"]
            rolled = inpt.roll(1, 0)
            output = inpt.clone()
            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

            if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
                output = tv_tensors.wrap(output, like=inpt)

            return output
        else:
            return inpt


class JPEG(Transform):
    """Apply JPEG compression and decompression to the given images.

    If the input is a :class:`torch.Tensor`, it is expected
    to be of dtype uint8, on CPU, and have [..., 3 or 1, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        quality (sequence or number): JPEG quality, from 1 to 100. Lower means more compression.
            If quality is a sequence like (min, max), it specifies the range of JPEG quality to
            randomly select from (inclusive of both ends).

    Returns:
        image with JPEG compression.
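
    Example:
        >>> # A minimal sketch; the input shape and quality range are illustrative.
        >>> import torch
        >>> from torchvision.transforms import v2 as transforms
        >>> img = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
        >>> jpeg = transforms.JPEG(quality=(50, 75))  # quality sampled uniformly from [50, 75]
        >>> out = jpeg(img)
        >>> out.shape, out.dtype
        (torch.Size([3, 16, 16]), torch.uint8)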
    """

    def __init__(self, quality: Union[int, Sequence[int]]):
        super().__init__()
        if isinstance(quality, int):
            quality = [quality, quality]
        else:
            _check_sequence_input(quality, "quality", req_sizes=(2,))

        if not (1 <= quality[0] <= quality[1] <= 100 and isinstance(quality[0], int) and isinstance(quality[1], int)):
            raise ValueError(f"quality must be an integer from 1 to 100, got {quality =}")

        self.quality = quality

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item()
        return dict(quality=quality)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.jpeg, inpt, quality=params["quality"])