import math
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch

from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor


ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]


class _AutoAugmentBase(Transform):
    def __init__(
        self,
        *,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__()
        self.interpolation = _check_interpolation(interpolation)
        self.fill = fill
        self._fill = _setup_fill_arg(fill)

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        params = super()._extract_params_for_v1_transform()

        if isinstance(params["fill"], dict):
            raise ValueError(f"{type(self).__name__}() cannot be scripted when `fill` is a dictionary.")

        return params

    def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
        keys = tuple(dct.keys())
        key = keys[int(torch.randint(len(keys), ()))]
        return key, dct[key]

    def _flatten_and_extract_image_or_video(
        self,
        inputs: Any,
        unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        needs_transform_list = self._needs_transform_list(flat_inputs)

        image_or_videos = []
        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
            if needs_transform and check_type(
                inpt,
                (
                    tv_tensors.Image,
                    PIL.Image.Image,
                    is_pure_tensor,
                    tv_tensors.Video,
                ),
            ):
                image_or_videos.append((idx, inpt))
            elif isinstance(inpt, unsupported_types):
                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

        if not image_or_videos:
            raise TypeError("Found no image in the sample.")
        if len(image_or_videos) > 1:
            raise TypeError(
                f"Auto augment transformations are only properly defined for a single image or video, "
                f"but found {len(image_or_videos)}."
            )

        idx, image_or_video = image_or_videos[0]
        return (flat_inputs, spec, idx), image_or_video

    def _unflatten_and_insert_image_or_video(
        self,
        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
        image_or_video: ImageOrVideo,
    ) -> Any:
        flat_inputs, spec, idx = flat_inputs_with_spec
        flat_inputs[idx] = image_or_video
        return tree_unflatten(flat_inputs, spec)

    def _apply_image_or_video_transform(
        self,
        image: ImageOrVideo,
        transform_id: str,
        magnitude: float,
        interpolation: Union[InterpolationMode, int],
        fill: Dict[Union[Type, str], _FillTypeJIT],
    ) -> ImageOrVideo:
        # Note: this cast is wrong and is only here to make mypy happy (it disagrees with torchscript)
        image = cast(torch.Tensor, image)
        fill_ = _get_fill(fill, type(image))

        if transform_id == "Identity":
            return image
        elif transform_id == "ShearX":
            # magnitude should be arctan(magnitude)
            # official autoaug: (1, level, 0, 0, 1, 0)
            # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
            # compared to
            # torchvision:      (1, tan(level), 0, 0, 1, 0)
            # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
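            # e.g. for level 0.3: degrees(atan(0.3)) is ~16.7 degrees, and a shear angle of
            # 16.7 degrees applies tan(16.7 deg) = 0.3 again, matching the reference matrix above.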
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[math.degrees(math.atan(magnitude)), 0.0],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "ShearY":
            # magnitude should be arctan(magnitude)
            # See above
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[0.0, math.degrees(math.atan(magnitude))],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "TranslateX":
            return F.affine(
                image,
                angle=0.0,
                translate=[int(magnitude), 0],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "TranslateY":
            return F.affine(
                image,
                angle=0.0,
                translate=[0, int(magnitude)],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "Rotate":
            return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
        elif transform_id == "Brightness":
            return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
        elif transform_id == "Color":
            return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
        elif transform_id == "Contrast":
            return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
        elif transform_id == "Sharpness":
            return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
        elif transform_id == "Posterize":
            return F.posterize(image, bits=int(magnitude))
        elif transform_id == "Solarize":
            bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
            return F.solarize(image, threshold=bound * magnitude)
        elif transform_id == "AutoContrast":
            return F.autocontrast(image)
        elif transform_id == "Equalize":
            return F.equalize(image)
        elif transform_id == "Invert":
            return F.invert(image)
        else:
            raise ValueError(f"No transform available for {transform_id}")


class AutoAugment(_AutoAugmentBase):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy, optional): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If the input is a Tensor, only ``InterpolationMode.NEAREST`` and ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If a number is given, it is used for all bands.
    """
    _v1_transform_cls = _transforms.AutoAugment

    _AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
        "Invert": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.policy = policy
        self._policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)  # type: ignore[arg-type]

        policy = self._policies[int(torch.randint(len(self._policies), ()))]

        for transform_id, probability, magnitude_idx in policy:
            if not torch.rand(()) <= probability:
                continue

            magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

            magnitudes = magnitudes_fn(10, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[magnitude_idx])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0

            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


class RandAugment(_AutoAugmentBase):
    r"""RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int, optional): Number of augmentation transformations to apply sequentially.
        magnitude (int, optional): Magnitude for all the transformations.
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If the input is a Tensor, only ``InterpolationMode.NEAREST`` and ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If a number is given, it is used for all bands.
    """

    _v1_transform_cls = _transforms.RandAugment
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)  # type: ignore[arg-type]

        for _ in range(self.num_ops):
            transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[self.magnitude])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0
            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


class TrivialAugmentWide(_AutoAugmentBase):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If the input is a Tensor, only ``InterpolationMode.NEAREST`` and ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If a number is given, it is used for all bands.
    """

    _v1_transform_cls = _transforms.TrivialAugmentWide
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ):
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)  # type: ignore[arg-type]

        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
        if magnitudes is not None:
            magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
        else:
            magnitude = 0.0

        image_or_video = self._apply_image_or_video_transform(
            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
        )
        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


class AugMix(_AutoAugmentBase):
    r"""AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int, optional): The severity of base augmentation operators. Default is ``3``.
        mixture_width (int, optional): The number of augmentation chains. Default is ``3``.
        chain_depth (int, optional): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
            Default is ``-1``.
        alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If the input is a Tensor, only ``InterpolationMode.NEAREST`` and ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If a number is given, it is used for all bands.
    """

    _v1_transform_cls = _transforms.AugMix

    _PARTIAL_AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }
    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
        **_PARTIAL_AUGMENTATION_SPACE,
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
    }

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops

    def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
        # Kept as a separate method so that we can override it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(orig_image_or_video)  # type: ignore[arg-type]

        if isinstance(orig_image_or_video, torch.Tensor):
            image_or_video = orig_image_or_video
        else:  # isinstance(inpt, PIL.Image.Image):
            image_or_video = F.pil_to_tensor(orig_image_or_video)

        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

        orig_dims = list(image_or_video.shape)
        expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
        batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
        # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd those of the
        # augmented image or video.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )
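        # `m` has one row per sample: m[:, 0] weights the original input, while m[:, 1] is split
        # across the augmentation chains below, so the final mix stays a convex combination.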

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].reshape([batch_dims[0], -1])

        mix = m[:, 0].reshape(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                if magnitudes is not None:
                    magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                    if signed and torch.rand(()) <= 0.5:
                        magnitude *= -1
                else:
                    magnitude = 0.0

                aug = self._apply_image_or_video_transform(aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill)  # type: ignore[assignment]
            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
        mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

        if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
            mix = tv_tensors.wrap(mix, like=orig_image_or_video)
        elif isinstance(orig_image_or_video, PIL.Image.Image):
            mix = F.to_pil_image(mix)

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)