"dialogctrl/ner/src/model.py" did not exist on "90e0a0dd08159e1c95f4f9d99bb8687f327d36c3"
_auto_augment.py 31.3 KB
Newer Older
1
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch

from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision import datapoints, transforms as _transforms
from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

from ._utils import _get_fill, _setup_fill_arg
from .utils import check_type, is_pure_tensor


ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]


class _AutoAugmentBase(Transform):
    def __init__(
        self,
        *,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__()
        self.interpolation = _check_interpolation(interpolation)
        self.fill = fill
        self._fill = _setup_fill_arg(fill)

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        params = super()._extract_params_for_v1_transform()

        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
            raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")

        return params
    def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
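        # Uniformly sample one (transform name, (magnitudes_fn, signed)) entry from the given space.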
        keys = tuple(dct.keys())
        key = keys[int(torch.randint(len(keys), ()))]
        return key, dct[key]

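    # The two helpers below extract the single image or video from an arbitrarily nested input
    # structure and, after augmentation, put the transformed result back into its original place.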
    def _flatten_and_extract_image_or_video(
        self,
        inputs: Any,
        unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        needs_transform_list = self._needs_transform_list(flat_inputs)

        image_or_videos = []
        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
            if needs_transform and check_type(
                inpt,
                (
                    datapoints.Image,
                    PIL.Image.Image,
                    is_pure_tensor,
                    datapoints.Video,
                ),
            ):
                image_or_videos.append((idx, inpt))
            elif isinstance(inpt, unsupported_types):
                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

        if not image_or_videos:
            raise TypeError("Found no image in the sample.")
        if len(image_or_videos) > 1:
            raise TypeError(
                f"Auto augment transformations are only properly defined for a single image or video, "
                f"but found {len(image_or_videos)}."
            )

        idx, image_or_video = image_or_videos[0]
        return (flat_inputs, spec, idx), image_or_video

    def _unflatten_and_insert_image_or_video(
        self,
        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
        image_or_video: ImageOrVideo,
    ) -> Any:
        flat_inputs, spec, idx = flat_inputs_with_spec
        flat_inputs[idx] = image_or_video
        return tree_unflatten(flat_inputs, spec)
    def _apply_image_or_video_transform(
        self,
        image: ImageOrVideo,
        transform_id: str,
        magnitude: float,
        interpolation: Union[InterpolationMode, int],
        fill: Dict[Union[Type, str], _FillTypeJIT],
    ) -> ImageOrVideo:
        fill_ = _get_fill(fill, type(image))

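        # Dispatch on the transform name; each branch converts `magnitude` into the units the
        # corresponding functional op expects (degrees for shears and rotations, pixels for
        # translations, a `1.0 + magnitude` factor for the color ops, bits for posterize, and a
        # fraction of the value range for solarize).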
        if transform_id == "Identity":
            return image
        elif transform_id == "ShearX":
            # magnitude should be arctan(magnitude)
            # official autoaug: (1, level, 0, 0, 1, 0)
            # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
            # compared to
            # torchvision:      (1, tan(level), 0, 0, 1, 0)
            # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
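            # e.g. a magnitude of 0.3 is applied as a shear angle of math.degrees(math.atan(0.3)) ~= 16.7 degrees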
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[math.degrees(math.atan(magnitude)), 0.0],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "ShearY":
            # magnitude should be arctan(magnitude)
            # See above
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[0.0, math.degrees(math.atan(magnitude))],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "TranslateX":
            return F.affine(
                image,
                angle=0.0,
                translate=[int(magnitude), 0],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "TranslateY":
            return F.affine(
                image,
                angle=0.0,
                translate=[0, int(magnitude)],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "Rotate":
            return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
        elif transform_id == "Brightness":
            return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
        elif transform_id == "Color":
            return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
        elif transform_id == "Contrast":
            return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
        elif transform_id == "Sharpness":
            return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
        elif transform_id == "Posterize":
            return F.posterize(image, bits=int(magnitude))
        elif transform_id == "Solarize":
            bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
            return F.solarize(image, threshold=bound * magnitude)
        elif transform_id == "AutoContrast":
            return F.autocontrast(image)
        elif transform_id == "Equalize":
            return F.equalize(image)
        elif transform_id == "Invert":
            return F.invert(image)
        else:
            raise ValueError(f"No transform available for {transform_id}")


class AutoAugment(_AutoAugmentBase):
    r"""[BETA] AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.

    .. v2betastatus:: AutoAugment transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy, optional): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
    """
    _v1_transform_cls = _transforms.AutoAugment

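    # Each entry maps a transform name to (magnitudes_fn, signed): magnitudes_fn(num_bins, height, width)
    # returns the tensor of candidate magnitudes (or None for parameterless ops), and `signed` marks
    # ops whose sampled magnitude is negated with probability 0.5.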
    _AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
        "Invert": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.policy = policy
        self._policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        policy = self._policies[int(torch.randint(len(self._policies), ()))]

        for transform_id, probability, magnitude_idx in policy:
            if not torch.rand(()) <= probability:
                continue

            magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

            magnitudes = magnitudes_fn(10, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[magnitude_idx])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0

            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


class RandAugment(_AutoAugmentBase):
    r"""[BETA] RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.

    .. v2betastatus:: RandAugment transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int, optional): Number of augmentation transformations to apply sequentially.
        magnitude (int, optional): Magnitude for all the transformations.
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
    """

    _v1_transform_cls = _transforms.RandAugment
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        for _ in range(self.num_ops):
            transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[self.magnitude])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0
            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


class TrivialAugmentWide(_AutoAugmentBase):
    r"""[BETA] Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.

    .. v2betastatus:: TrivialAugmentWide transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
    """

    _v1_transform_cls = _transforms.TrivialAugmentWide
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ):
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
        if magnitudes is not None:
            magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
        else:
            magnitude = 0.0

        image_or_video = self._apply_image_or_video_transform(
            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
        )
        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


class AugMix(_AutoAugmentBase):
    r"""[BETA] AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.

    .. v2betastatus:: AugMix transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int, optional): The severity of base augmentation operators. Default is ``3``.
        mixture_width (int, optional): The number of augmentation chains. Default is ``3``.
        chain_depth (int, optional): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
            Default is ``-1``.
        alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
    """

    _v1_transform_cls = _transforms.AugMix

    _PARTIAL_AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }
    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
        **_PARTIAL_AUGMENTATION_SPACE,
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
    }

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops

    def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
        # Must be a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(orig_image_or_video)

        if isinstance(orig_image_or_video, torch.Tensor):
            image_or_video = orig_image_or_video
        else:  # isinstance(orig_image_or_video, PIL.Image.Image)
            image_or_video = F.pil_to_tensor(orig_image_or_video)

        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

        orig_dims = list(image_or_video.shape)
        expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4
        batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
        # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
        # augmented image or video.
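        # `m` has shape (batch_size, 2): m[:, 0] weights the original input, while m[:, 1] is split
        # across the `mixture_width` augmented chains by the Dirichlet weights sampled below.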
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].reshape([batch_dims[0], -1])

        mix = m[:, 0].reshape(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                if magnitudes is not None:
                    magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                    if signed and torch.rand(()) <= 0.5:
                        magnitude *= -1
                else:
                    magnitude = 0.0

                aug = self._apply_image_or_video_transform(
                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
                )
            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
        mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

        if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
            mix = datapoints.wrap(mix, like=orig_image_or_video)
        elif isinstance(orig_image_or_video, PIL.Image.Image):
            mix = F.to_pil_image(mix)

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)