import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch

from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision import datapoints, transforms as _transforms
from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_spatial_size

from ._utils import _setup_fill_arg
from .utils import check_type, is_simple_tensor


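# All of the transforms in this module share the same machinery: each one declares an
# ``_AUGMENTATION_SPACE`` dict mapping a ``transform_id`` (e.g. "ShearX") to a pair
# ``(magnitudes_fn, signed)``. ``magnitudes_fn(num_bins, height, width)`` returns the tensor of
# candidate magnitudes, or ``None`` for parameter-free ops such as "Equalize", and ``signed`` says
# whether a sampled magnitude may be negated with probability 0.5 before being dispatched through
# ``_apply_image_or_video_transform``.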
class _AutoAugmentBase(Transform):
    def __init__(
        self,
        *,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
    ) -> None:
        super().__init__()
        self.interpolation = _check_interpolation(interpolation)
        self.fill = _setup_fill_arg(fill)

    def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
        keys = tuple(dct.keys())
        key = keys[int(torch.randint(len(keys), ()))]
        return key, dct[key]

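    # The two helpers below let the transforms accept arbitrarily nested samples: the input is
    # flattened with ``tree_flatten``, exactly one image or video is extracted and augmented, and
    # ``tree_unflatten`` puts the result back into its original position in the sample.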
    def _flatten_and_extract_image_or_video(
        self,
        inputs: Any,
        unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask),
    ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints._ImageType, datapoints._VideoType]]:
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        needs_transform_list = self._needs_transform_list(flat_inputs)

        image_or_videos = []
        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
            if needs_transform and check_type(
                inpt,
                (
                    datapoints.Image,
                    PIL.Image.Image,
                    is_simple_tensor,
                    datapoints.Video,
                ),
            ):
                image_or_videos.append((idx, inpt))
            elif isinstance(inpt, unsupported_types):
                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

        if not image_or_videos:
            raise TypeError("Found no image in the sample.")
        if len(image_or_videos) > 1:
            raise TypeError(
                f"Auto augment transformations are only properly defined for a single image or video, "
                f"but found {len(image_or_videos)}."
            )

        idx, image_or_video = image_or_videos[0]
        return (flat_inputs, spec, idx), image_or_video

    def _unflatten_and_insert_image_or_video(
        self,
        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
        image_or_video: Union[datapoints._ImageType, datapoints._VideoType],
    ) -> Any:
        flat_inputs, spec, idx = flat_inputs_with_spec
        flat_inputs[idx] = image_or_video
        return tree_unflatten(flat_inputs, spec)

    def _apply_image_or_video_transform(
        self,
        image: Union[datapoints._ImageType, datapoints._VideoType],
        transform_id: str,
        magnitude: float,
        interpolation: Union[InterpolationMode, int],
        fill: Dict[Type, datapoints._FillTypeJIT],
    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
        fill_ = fill[type(image)]

        if transform_id == "Identity":
            return image
        elif transform_id == "ShearX":
            # magnitude should be arctan(magnitude)
            # official autoaug: (1, level, 0, 0, 1, 0)
            # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
            # compared to
            # torchvision:      (1, tan(level), 0, 0, 1, 0)
            # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[math.degrees(math.atan(magnitude)), 0.0],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "ShearY":
            # magnitude should be arctan(magnitude)
            # See above
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[0.0, math.degrees(math.atan(magnitude))],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "TranslateX":
            return F.affine(
                image,
                angle=0.0,
                translate=[int(magnitude), 0],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "TranslateY":
            return F.affine(
                image,
                angle=0.0,
                translate=[0, int(magnitude)],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "Rotate":
            return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
        elif transform_id == "Brightness":
            return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
        elif transform_id == "Color":
            return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
        elif transform_id == "Contrast":
            return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
        elif transform_id == "Sharpness":
            return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
        elif transform_id == "Posterize":
            return F.posterize(image, bits=int(magnitude))
        elif transform_id == "Solarize":
            bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
            return F.solarize(image, threshold=bound * magnitude)
        elif transform_id == "AutoContrast":
            return F.autocontrast(image)
        elif transform_id == "Equalize":
            return F.equalize(image)
        elif transform_id == "Invert":
            return F.invert(image)
        else:
            raise ValueError(f"No transform available for {transform_id}")


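# Illustrative usage sketch (assuming the transform is used through the public
# torchvision.transforms.v2 namespace):
#
#   from torchvision.transforms import v2
#
#   transform = v2.AutoAugment(policy=v2.AutoAugmentPolicy.IMAGENET)
#   augmented = transform(image)  # a single PIL image, uint8 tensor, or datapoints.Image/Video
#
# Each call samples one sub-policy, i.e. a pair of (transform_id, probability, magnitude_idx)
# entries, and applies each op with its associated probability.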
class AutoAugment(_AutoAugmentBase):
    _v1_transform_cls = _transforms.AutoAugment

    _AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
        "Invert": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.policy = policy
        self._policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_spatial_size(image_or_video)

        policy = self._policies[int(torch.randint(len(self._policies), ()))]

        for transform_id, probability, magnitude_idx in policy:
            if not torch.rand(()) <= probability:
                continue

            magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

            magnitudes = magnitudes_fn(10, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[magnitude_idx])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0

            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


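# RandAugment has no learned policy: an illustrative call such as ``v2.RandAugment(num_ops=2, magnitude=9)``
# applies ``num_ops`` uniformly sampled ops per input, all at the same fixed ``magnitude`` bin out of
# ``num_magnitude_bins``.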
class RandAugment(_AutoAugmentBase):
    _v1_transform_cls = _transforms.RandAugment
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_spatial_size(image_or_video)

        for _ in range(self.num_ops):
            transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[self.magnitude])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0
            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


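# TrivialAugmentWide applies exactly one uniformly sampled op per call, at a magnitude bin drawn
# uniformly from ``[0, num_magnitude_bins)``. Its magnitude ranges are wider than RandAugment's
# (e.g. rotations up to 135 degrees instead of 30).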
class TrivialAugmentWide(_AutoAugmentBase):
    _v1_transform_cls = _transforms.TrivialAugmentWide
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_spatial_size(image_or_video)

        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
        if magnitudes is not None:
            magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
        else:
            magnitude = 0.0

        image_or_video = self._apply_image_or_video_transform(
            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
        )
        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


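# AugMix blends the original input with ``mixture_width`` independently augmented copies, each
# produced by a chain of ``chain_depth`` ops (or a random 1-3 ops when ``chain_depth`` is left at -1),
# with blending weights drawn from Dirichlet/Beta distributions parameterized by ``alpha``.
# Illustrative sketch: ``v2.AugMix(severity=3, mixture_width=3, alpha=1.0)(image)``.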
class AugMix(_AutoAugmentBase):
    _v1_transform_cls = _transforms.AugMix

    _PARTIAL_AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }
    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
        **_PARTIAL_AUGMENTATION_SPACE,
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
    }

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops

    def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
        # Must be in a separate method so that we can override it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_spatial_size(orig_image_or_video)

        if isinstance(orig_image_or_video, torch.Tensor):
            image_or_video = orig_image_or_video
        else:  # isinstance(orig_image_or_video, PIL.Image.Image)
            image_or_video = F.pil_to_tensor(orig_image_or_video)

        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

        orig_dims = list(image_or_video.shape)
        expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4
        batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
        # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of the
        # augmented image or video.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].reshape([batch_dims[0], -1])

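        # The loop below therefore computes
        #   mix = m[:, 0] * batch + sum_i combined_weights[:, i] * aug_i
        # where each aug_i is the (possibly batched) input pushed through an independent chain of
        # randomly sampled ops.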
        mix = m[:, 0].reshape(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                if magnitudes is not None:
                    magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                    if signed and torch.rand(()) <= 0.5:
                        magnitude *= -1
                else:
                    magnitude = 0.0

                aug = self._apply_image_or_video_transform(
                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
                )
            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
        mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

        if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
            mix = orig_image_or_video.wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
        elif isinstance(orig_image_or_video, PIL.Image.Image):
            mix = F.to_image_pil(mix)

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)