"scripts/vscode:/vscode.git/clone" did not exist on "8ec850b81dfee0a52ed09caf07a1ea9670429307"
autoaugment.py 12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
"""AutoAugment data augmentation policy for ImageNet.

Implements the fixed AutoAugment data augmentation policy for ImageNet
provided in Appendix A, Table 9 in reference [1]. Does not include any
of the search code.

Reference:
[1] https://arxiv.org/abs/1805.09501

Code adapted from:
https://github.com/DeepVoltaire/AutoAugment
"""

import random

import numpy as np
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageOps

# Maximum integer strength (magnitude index) of an augmentation, if applicable.
# SubPolicy builds _MAX_LEVEL + 1 evenly spaced magnitude levels per operation,
# so valid indices run from 0 to _MAX_LEVEL inclusive.
_MAX_LEVEL = 10


class ImageNetPolicy:
    """Fixed AutoAugment policy for ImageNet training.

    Holds the 25 pre-defined sub-policies listed in Reference [1] and, on
    every call, applies one of them (picked uniformly at random) to the
    input image.

    Usage example as a Pytorch Transform:
    >>> transform=transforms.Compose([transforms.Resize(256),
    >>>                               ImageNetPolicy(),
    >>>                               transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        """Build the list of sub-policies.

        Args:
            fillcolor (tuple): RGB color components forwarded to every
                SubPolicy for use as fill color when needed
                (default: (128, 128, 128), which corresponds to gray).
        """
        # One row per sub-policy, laid out as
        # (operation1, probability1, magnitude_idx1,
        #  operation2, probability2, magnitude_idx2).
        # The two operations of a row are applied in sequence upon call.
        table = (
            ("posterize", 0.4, 8, "rotate", 0.6, 9),
            ("solarize", 0.6, 5, "autocontrast", 0.6, 5),
            ("equalize", 0.8, 8, "equalize", 0.6, 3),
            ("posterize", 0.6, 7, "posterize", 0.6, 6),
            ("equalize", 0.4, 7, "solarize", 0.2, 4),
            ("equalize", 0.4, 4, "rotate", 0.8, 8),
            ("solarize", 0.6, 3, "equalize", 0.6, 7),
            ("posterize", 0.8, 5, "equalize", 1.0, 2),
            ("rotate", 0.2, 3, "solarize", 0.6, 8),
            ("equalize", 0.6, 8, "posterize", 0.4, 6),
            ("rotate", 0.8, 8, "color", 0.4, 0),
            ("rotate", 0.4, 9, "equalize", 0.6, 2),
            ("equalize", 0.0, 7, "equalize", 0.8, 8),
            ("invert", 0.6, 4, "equalize", 1.0, 8),
            ("color", 0.6, 4, "contrast", 1.0, 8),
            ("rotate", 0.8, 8, "color", 1.0, 2),
            ("color", 0.8, 8, "solarize", 0.8, 7),
            ("sharpness", 0.4, 7, "invert", 0.6, 8),
            ("shearX", 0.6, 5, "equalize", 1.0, 9),
            ("color", 0.4, 0, "equalize", 0.6, 3),
            ("equalize", 0.4, 7, "solarize", 0.2, 4),
            ("solarize", 0.6, 5, "autocontrast", 0.6, 5),
            ("invert", 0.6, 4, "equalize", 1.0, 8),
            ("color", 0.6, 4, "contrast", 1.0, 8),
            ("equalize", 0.8, 8, "equalize", 0.6, 3),
        )
        self.policies = [SubPolicy(*row, fillcolor) for row in table]

    def __call__(self, img):
        """Apply one randomly selected sub-policy to img and return it."""
        chosen = self.policies[random.randrange(len(self.policies))]
        return chosen(img)

    def __repr__(self):
        """Return the fixed name of this transform."""
        return "ImageNetPolicy"


class SubPolicy:
    """Definition of a SubPolicy.

    A SubPolicy consists of two augmentation operations,
    each of those parametrized as operation, probability, magnitude.
    The two operations are applied sequentially on the image upon call,
    each one independently with its own probability.
    """

    def __init__(
        self,
        operation1,
        probability1,
        magnitude_idx1,
        operation2,
        probability2,
        magnitude_idx2,
        fillcolor,
    ):
        """Initialize a SubPolicy.

        Args:
            operation1 (str): Key specifying the first augmentation
                operation. There are fourteen key values altogether (see
                supported_ops below listing supported operations).
            probability1 (float): Probability within [0., 1.] of applying
                the first augmentation operation.
            magnitude_idx1 (int): Integer within [0, 10] specifying the
                strength of the first operation as an index further used to
                derive the magnitude from a range of possible values.
            operation2 (str): Key specifying the second augmentation
                operation.
            probability2 (float): Probability within [0., 1.] of applying
                the second augmentation operation.
            magnitude_idx2 (int): Integer within [0, 10] specifying the
                strength of the second operation as an index further used
                to derive the magnitude from a range of possible values.
            fillcolor (tuple): RGB color components of the color to be used
                for filling by the shear and translate operations.

        Raises:
            AssertionError: If an operation key is unsupported, a
                probability lies outside [0., 1.], or a magnitude index is
                not an integer within [0, 10].
        """
        # List of supported operations for operation1 and operation2.
        supported_ops = [
            "shearX",
            "shearY",
            "translateX",
            "translateY",
            "rotate",
            "color",
            "posterize",
            "solarize",
            "contrast",
            "sharpness",
            "brightness",
            "autocontrast",
            "equalize",
            "invert",
        ]
        assert (operation1 in supported_ops) and (
            operation2 in supported_ops
        ), "SubPolicy:one of oper1 or oper2 refers to an unsupported operation."

        assert (
            0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0
        ), "SubPolicy: prob1 and prob2 should be within [0., 1.]."

        assert (
            isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10
        ), "SubPolicy: idx1 should be specified as an integer within [0, 10]."

        assert (
            isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10
        ), "SubPolicy: idx2 should be specified as an integer within [0, 10]."

        # Define a dictionary where each key refers to a specific type of
        # augmentation and the corresponding value is a range of
        # _MAX_LEVEL + 1 (i.e. eleven) possible magnitude values for that
        # augmentation, indexed by the magnitude index.
        num_levels = _MAX_LEVEL + 1
        ranges = {
            "shearX": np.linspace(0, 0.3, num_levels),
            "shearY": np.linspace(0, 0.3, num_levels),
            # Translations are stored as a fraction and multiplied by the
            # image width / height when the operation is applied.
            "translateX": np.linspace(0, 150 / 331, num_levels),
            "translateY": np.linspace(0, 150 / 331, num_levels),
            "rotate": np.linspace(0, 30, num_levels),
            "color": np.linspace(0.0, 0.9, num_levels),
            # Builtin int replaces the np.int alias, which was deprecated
            # in NumPy 1.20 and removed in NumPy 1.24.
            "posterize": np.round(np.linspace(8, 4, num_levels), 0).astype(
                int
            ),
            "solarize": np.linspace(256, 0, num_levels),  # range [0, 256]
            "contrast": np.linspace(0.0, 0.9, num_levels),
            "sharpness": np.linspace(0.0, 0.9, num_levels),
            "brightness": np.linspace(0.0, 0.9, num_levels),
            "autocontrast": [0]
            * num_levels,  # This augmentation doesn't use magnitude parameter.
            "equalize": [0]
            * num_levels,  # This augmentation doesn't use magnitude parameter.
            "invert": [0]
            * num_levels,  # This augmentation doesn't use magnitude parameter.
        }

        def rotate_with_fill(img, magnitude):
            """Define rotation transformation with fill.

            The input image is first rotated, then it is blended together with
            a gray mask of the same size. Note that fillcolor as defined
            elsewhere in this module doesn't apply here.

            Args:
                img (PIL Image): image to rotate.
                magnitude (float): rotation angle in degrees.
            Returns:
                rotated_filled (PIL Image): rotated image with gray filling for
                disoccluded areas unveiled by the rotation, converted back
                to the mode of the input image.
            """
            rotated = img.convert("RGBA").rotate(magnitude)
            # The rotated RGBA image doubles as the compositing mask, so the
            # gray layer shows through wherever the rotation left no pixels.
            rotated_filled = Image.composite(
                rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated
            )
            return rotated_filled.convert(img.mode)

        # Define a dictionary of augmentation functions where each key refers
        # to a specific type of augmentation and the corresponding value defines
        # the augmentation itself using a lambda function.
        # Operations multiplying the magnitude by random.choice([-1, 1]) pick
        # the direction / sign anew every time they are applied.
        # pylint: disable=unnecessary-lambda
        func_dict = {
            "shearX": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor,
            ),
            "shearY": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor,
            ),
            "translateX": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (
                    1,
                    0,
                    magnitude * img.size[0] * random.choice([-1, 1]),
                    0,
                    1,
                    0,
                ),
                fillcolor=fillcolor,
            ),
            "translateY": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (
                    1,
                    0,
                    0,
                    0,
                    1,
                    magnitude * img.size[1] * random.choice([-1, 1]),
                ),
                fillcolor=fillcolor,
            ),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
                1 + magnitude * random.choice([-1, 1])
            ),
            "posterize": lambda img, magnitude: ImageOps.posterize(
                img, magnitude
            ),
            "solarize": lambda img, magnitude: ImageOps.solarize(
                img, magnitude
            ),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(
                img
            ).enhance(1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(
                img
            ).enhance(1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(
                img
            ).enhance(1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img),
        }

        # Store probability, function and magnitude of the first augmentation
        # for the sub-policy.
        self.probability1 = probability1
        self.operation1 = func_dict[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]

        # Store probability, function and magnitude of the second augmentation
        # for the sub-policy.
        self.probability2 = probability2
        self.operation2 = func_dict[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        """Apply the two operations to img, each with its own probability.

        Args:
            img (PIL Image): image to augment.
        Returns:
            The image after zero, one or both operations were applied.
        """
        # Randomly apply operation 1.
        if random.random() < self.probability1:
            img = self.operation1(img, self.magnitude1)

        # Randomly apply operation 2.
        if random.random() < self.probability2:
            img = self.operation2(img, self.magnitude2)

        return img