"vscode:/vscode.git/clone" did not exist on "758b887ad19df35bc31f570e898696c90745a3af"
_auto_augment.py 31 KB
Newer Older
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch

from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision import datapoints, transforms as _transforms
from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

from ._utils import _get_fill, _setup_fill_arg
from .utils import check_type, is_simple_tensor

ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]


class _AutoAugmentBase(Transform):
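    """Shared base class for the AutoAugment-family transforms below: it extracts the single
    image or video from the input structure, applies the requested op at a sampled magnitude,
    and re-inserts the result into the original structure."""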
    def __init__(
        self,
        *,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__()
        self.interpolation = _check_interpolation(interpolation)
        self.fill = _setup_fill_arg(fill)

    def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
        keys = tuple(dct.keys())
        key = keys[int(torch.randint(len(keys), ()))]
        return key, dct[key]

    def _flatten_and_extract_image_or_video(
        self,
        inputs: Any,
        unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        needs_transform_list = self._needs_transform_list(flat_inputs)

        image_or_videos = []
        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
            if needs_transform and check_type(
                inpt,
                (
                    datapoints.Image,
                    PIL.Image.Image,
                    is_simple_tensor,
                    datapoints.Video,
                ),
            ):
                image_or_videos.append((idx, inpt))
            elif isinstance(inpt, unsupported_types):
                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

        if not image_or_videos:
            raise TypeError("Found no image in the sample.")
        if len(image_or_videos) > 1:
            raise TypeError(
                f"Auto augment transformations are only properly defined for a single image or video, "
                f"but found {len(image_or_videos)}."
            )

        idx, image_or_video = image_or_videos[0]
        return (flat_inputs, spec, idx), image_or_video

    def _unflatten_and_insert_image_or_video(
        self,
        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
        image_or_video: ImageOrVideo,
    ) -> Any:
        flat_inputs, spec, idx = flat_inputs_with_spec
        flat_inputs[idx] = image_or_video
        return tree_unflatten(flat_inputs, spec)

    def _apply_image_or_video_transform(
        self,
        image: ImageOrVideo,
        transform_id: str,
        magnitude: float,
        interpolation: Union[InterpolationMode, int],
        fill: Dict[Union[Type, str], _FillTypeJIT],
    ) -> ImageOrVideo:
        fill_ = _get_fill(fill, type(image))

        if transform_id == "Identity":
            return image
        elif transform_id == "ShearX":
            # magnitude should be arctan(magnitude)
            # official autoaug: (1, level, 0, 0, 1, 0)
            # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
            # compared to
            # torchvision:      (1, tan(level), 0, 0, 1, 0)
            # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
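            # e.g. an official-autoaugment level of 0.3 corresponds to a shear angle of
            # math.degrees(math.atan(0.3)), i.e. roughly 16.7 degrees (illustrative value).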
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[math.degrees(math.atan(magnitude)), 0.0],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "ShearY":
            # magnitude should be arctan(magnitude)
            # See above
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[0.0, math.degrees(math.atan(magnitude))],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "TranslateX":
            return F.affine(
                image,
                angle=0.0,
                translate=[int(magnitude), 0],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "TranslateY":
            return F.affine(
                image,
                angle=0.0,
                translate=[0, int(magnitude)],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "Rotate":
            return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
        elif transform_id == "Brightness":
            return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
        elif transform_id == "Color":
            return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
        elif transform_id == "Contrast":
            return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
        elif transform_id == "Sharpness":
            return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
        elif transform_id == "Posterize":
            return F.posterize(image, bits=int(magnitude))
        elif transform_id == "Solarize":
            bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
            return F.solarize(image, threshold=bound * magnitude)
        elif transform_id == "AutoContrast":
            return F.autocontrast(image)
        elif transform_id == "Equalize":
            return F.equalize(image)
        elif transform_id == "Invert":
            return F.invert(image)
        else:
            raise ValueError(f"No transform available for {transform_id}")


class AutoAugment(_AutoAugmentBase):
    r"""[BETA] AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.

    .. v2betastatus:: AutoAugment transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy, optional): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
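
    Example (a minimal usage sketch; the random tensor stands in for a real image)::

        >>> import torch
        >>> from torchvision.transforms import v2
        >>> img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
        >>> transform = v2.AutoAugment(policy=v2.AutoAugmentPolicy.IMAGENET)
        >>> out = transform(img)  # same shape and dtype as the input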
    """
    _v1_transform_cls = _transforms.AutoAugment

    _AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
        "Invert": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.policy = policy
        self._policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        policy = self._policies[int(torch.randint(len(self._policies), ()))]

        for transform_id, probability, magnitude_idx in policy:
            if not torch.rand(()) <= probability:
                continue

            magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

            magnitudes = magnitudes_fn(10, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[magnitude_idx])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0

            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


class RandAugment(_AutoAugmentBase):
    r"""[BETA] RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.

    .. v2betastatus:: RandAugment transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int, optional): Number of augmentation transformations to apply sequentially.
        magnitude (int, optional): Magnitude for all the transformations.
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
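
    Example (a minimal usage sketch; the random tensor stands in for a real image)::

        >>> import torch
        >>> from torchvision.transforms import v2
        >>> img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
        >>> transform = v2.RandAugment(num_ops=2, magnitude=9)
        >>> out = transform(img)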
    """

    _v1_transform_cls = _transforms.RandAugment
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        for _ in range(self.num_ops):
            transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[self.magnitude])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0
            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


class TrivialAugmentWide(_AutoAugmentBase):
    r"""[BETA] Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.

    .. v2betastatus:: TrivialAugmentWide transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
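
    Example (a minimal usage sketch; the random tensor stands in for a real image)::

        >>> import torch
        >>> from torchvision.transforms import v2
        >>> img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
        >>> transform = v2.TrivialAugmentWide()
        >>> out = transform(img)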
    """

    _v1_transform_cls = _transforms.TrivialAugmentWide
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ):
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
        if magnitudes is not None:
            magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
        else:
            magnitude = 0.0

        image_or_video = self._apply_image_or_video_transform(
            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
        )
        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


class AugMix(_AutoAugmentBase):
    r"""[BETA] AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.

    .. v2betastatus:: AugMix transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int, optional): The severity of base augmentation operators. Default is ``3``.
        mixture_width (int, optional): The number of augmentation chains. Default is ``3``.
        chain_depth (int, optional): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
            Default is ``-1``.
        alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
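
    Example (a minimal usage sketch; the random tensor stands in for a real image)::

        >>> import torch
        >>> from torchvision.transforms import v2
        >>> img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
        >>> transform = v2.AugMix(severity=3, mixture_width=3)
        >>> out = transform(img)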
    """

    _v1_transform_cls = _transforms.AugMix

    _PARTIAL_AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }
    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
        **_PARTIAL_AUGMENTATION_SPACE,
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
    }

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops

    def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(orig_image_or_video)

        if isinstance(orig_image_or_video, torch.Tensor):
            image_or_video = orig_image_or_video
        else:  # isinstance(inpt, PIL.Image.Image):
            image_or_video = F.pil_to_tensor(orig_image_or_video)

        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

        orig_dims = list(image_or_video.shape)
        expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4
        batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
        # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
        # augmented image or video.
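        # (Each marginal of a Dirichlet([alpha, alpha]) draw is Beta(alpha, alpha) distributed.)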
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].reshape([batch_dims[0], -1])

        mix = m[:, 0].reshape(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                if magnitudes is not None:
                    magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                    if signed and torch.rand(()) <= 0.5:
                        magnitude *= -1
                else:
                    magnitude = 0.0

                aug = self._apply_image_or_video_transform(
                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
                )
            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
        mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

        if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
            mix = orig_image_or_video.wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
        elif isinstance(orig_image_or_video, PIL.Image.Image):
            mix = F.to_image_pil(mix)

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)