functional_pil.py 12.3 KB
Newer Older
1
import numbers
2
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
3

vfdev's avatar
vfdev committed
4
import numpy as np
5
import torch
6
from PIL import Image, ImageOps, ImageEnhance
7
from typing_extensions import Literal
vfdev's avatar
vfdev committed
8

9
10
11
12
13
14
15
try:
    import accimage
except ImportError:
    accimage = None


@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
    """Return True if ``img`` is a PIL image (or an accimage image when available)."""
    # accimage is an optional accelerated image backend; accept its images too.
    pil_classes: Tuple[type, ...] = (Image.Image,)
    if accimage is not None:
        pil_classes = pil_classes + (accimage.Image,)
    return isinstance(img, pil_classes)


vfdev's avatar
vfdev committed
23
@torch.jit.unused
def get_image_size(img: Any) -> List[int]:
    """Return the size of a PIL image as ``[width, height]``.

    Raises:
        TypeError: if ``img`` is not a PIL (or accimage) image.
    """
    if not _is_pil_image(img):
        raise TypeError(f"Unexpected type {type(img)}")
    width, height = img.size
    return [width, height]
vfdev's avatar
vfdev committed
28
29


30
@torch.jit.unused
def get_image_num_channels(img: Any) -> int:
    """Return the number of channels of a PIL image.

    The previous implementation hard-coded ``1`` for mode "L" and ``3`` for
    every other mode, which is wrong for e.g. "LA" (2 channels),
    "RGBA"/"CMYK" (4 channels) and "1"/"I"/"F" (1 channel). Counting the
    image's bands is exact for every mode.

    Raises:
        TypeError: if ``img`` is not a PIL (or accimage) image.
    """
    if _is_pil_image(img):
        if hasattr(img, "getbands"):
            return len(img.getbands())
        # NOTE(review): accimage images appear to expose a `channels`
        # attribute instead of `getbands` — confirm against accimage API.
        return img.channels
    raise TypeError(f"Unexpected type {type(img)}")
35
36


37
@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
    """Flip ``img`` horizontally (mirror left-right)."""
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")


@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
    """Flip ``img`` vertically (mirror top-bottom)."""
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_TOP_BOTTOM)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
51
52
53


@torch.jit.unused
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
    """Scale image brightness: 0 gives a black image, 1 the original image,
    values > 1 a brighter image."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    return ImageEnhance.Brightness(img).enhance(brightness_factor)


@torch.jit.unused
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
    """Scale image contrast: 0 gives a solid gray image, 1 the original image,
    values > 1 increase contrast."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    return ImageEnhance.Contrast(img).enhance(contrast_factor)


@torch.jit.unused
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
    """Scale color saturation: 0 gives a grayscale image, 1 the original image,
    values > 1 boost saturation."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    return ImageEnhance.Color(img).enhance(saturation_factor)


@torch.jit.unused
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
    """Shift the hue of ``img`` by ``hue_factor`` (a fraction of the full hue
    circle, must lie in [-0.5, 0.5]).

    Raises:
        ValueError: if ``hue_factor`` is out of range.
        TypeError: if ``img`` is not a PIL image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    original_mode = img.mode
    # These modes carry no color information, so there is no hue to shift.
    if original_mode in {"L", "1", "I", "F"}:
        return img

    hue, sat, val = img.convert("HSV").split()

    hue_array = np.array(hue, dtype=np.uint8)
    # uint8 wraparound on addition implements rotation across the hue boundary.
    with np.errstate(over="ignore"):
        hue_array += np.uint8(hue_factor * 255)
    hue = Image.fromarray(hue_array, "L")

    return Image.merge("HSV", (hue, sat, val)).convert(original_mode)
105
106


107
@torch.jit.unused
def adjust_gamma(
    img: Image.Image,
    gamma: float,
    gain: float = 1.0,
) -> Image.Image:
    """Apply gamma correction per channel: out = gain * (in / 255) ** gamma.

    Raises:
        TypeError: if ``img`` is not a PIL image.
        ValueError: if ``gamma`` is negative.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    original_mode = img.mode
    rgb = img.convert("RGB")
    # One 256-entry lookup table, replicated for the three RGB bands.
    # The (255 + 1 - 1e-3) scale keeps the top of the range just under 256.
    lut = [(255 + 1 - 1e-3) * gain * pow(level / 255.0, gamma) for level in range(256)] * 3
    rgb = rgb.point(lut)  # PIL applies the lookup table in C, which is fast
    return rgb.convert(original_mode)


129
@torch.jit.unused
def pad(
    img: Image.Image,
    padding: Union[int, List[int], Tuple[int, ...]],
    fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
    padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> Image.Image:
    """Pad ``img`` on its borders.

    Args:
        img: PIL image to pad.
        padding: border size(s) — an int (all four sides), a 1-element
            sequence (all four sides), a 2-element sequence
            (left/right, top/bottom), or a 4-element sequence
            (left, top, right, bottom). Negative values crop that side
            instead of padding it (non-constant modes only).
        fill: pixel value used when ``padding_mode`` is "constant".
        padding_mode: one of "constant", "edge", "reflect", "symmetric";
            the non-constant modes are delegated to ``numpy.pad``.

    Returns:
        The padded PIL image; the palette is preserved for mode "P" images.

    Raises:
        TypeError: on an invalid ``img``, ``padding``, ``fill`` or
            ``padding_mode`` argument type.
        ValueError: on a wrong-sized padding sequence or an unknown
            padding mode.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if not isinstance(padding, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, list):
        padding = tuple(padding)

    if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")

    if isinstance(padding, tuple) and len(padding) == 1:
        # Compatibility with `functional_tensor.pad`
        padding = padding[0]

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    if padding_mode == "constant":
        opts = _parse_fill(fill, img, name="fill")
        if img.mode == "P":
            # ImageOps.expand drops the palette; save it and restore it after.
            palette = img.getpalette()
            image = ImageOps.expand(img, border=padding, **opts)
            image.putpalette(palette)
            return image

        return ImageOps.expand(img, border=padding, **opts)
    else:
        # Normalize padding to explicit per-side amounts.
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        if isinstance(padding, tuple) and len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        if isinstance(padding, tuple) and len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]

        p = [pad_left, pad_top, pad_right, pad_bottom]
        # A negative pad amount means cropping that side instead of padding it.
        cropping = -np.minimum(p, 0)

        if cropping.any():
            crop_left, crop_top, crop_right, crop_bottom = cropping
            img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom))

        # Only the non-negative parts are actually padded below.
        pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)

        if img.mode == "P":
            # np.pad operates on the raw palette-index array; restore the
            # palette on the rebuilt image afterwards.
            palette = img.getpalette()
            img = np.asarray(img)
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
            img = Image.fromarray(img)
            img.putpalette(palette)
            return img

        img = np.asarray(img)
        # RGB image
        if len(img.shape) == 3:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
        # Grayscale image
        if len(img.shape) == 2:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

        return Image.fromarray(img)
vfdev's avatar
vfdev committed
207
208
209


@torch.jit.unused
def crop(
    img: Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Image.Image:
    """Crop a ``height`` x ``width`` region whose top-left corner is at
    ``(top, left)``."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    box = (left, top, left + width, top + height)
    return img.crop(box)
vfdev's avatar
vfdev committed
222
223
224


@torch.jit.unused
def resize(
    img: Image.Image,
    size: Union[Sequence[int], int],
    interpolation: int = Image.BILINEAR,
    max_size: Optional[int] = None,
) -> Image.Image:
    """Resize ``img``.

    An int (or length-1 sequence) ``size`` scales the smaller edge to that
    value, preserving aspect ratio and optionally capping the longer edge at
    ``max_size``. A length-2 sequence is an explicit (height, width) target.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
        raise TypeError(f"Got inappropriate size arg: {size}")

    # A length-1 sequence behaves exactly like a bare int.
    if isinstance(size, Sequence) and len(size) == 1:
        size = size[0]

    if not isinstance(size, int):
        # Explicit (h, w) target: max_size is not meaningful here.
        if max_size is not None:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )
        return img.resize(size[::-1], interpolation)

    width, height = img.size
    short, long = (width, height) if width <= height else (height, width)
    new_short = size
    new_long = int(size * long / short)

    if max_size is not None:
        if max_size <= size:
            raise ValueError(
                f"max_size = {max_size} must be strictly greater than the requested "
                f"size for the smaller edge size = {size}"
            )
        # Shrink both edges proportionally so the longer edge fits max_size.
        if new_long > max_size:
            new_short, new_long = int(max_size * new_short / new_long), max_size

    new_w, new_h = (new_short, new_long) if width <= height else (new_long, new_short)
    if (new_w, new_h) == (width, height):
        return img
    return img.resize((new_w, new_h), interpolation)
vfdev's avatar
vfdev committed
267
268
269


@torch.jit.unused
def _parse_fill(
    fill: Optional[Union[float, List[float], Tuple[float, ...]]],
    img: Image.Image,
    name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
    """Normalize a fill color for affine-style transforms and wrap it in a
    single-entry keyword dict keyed by ``name``.

    Raises:
        ValueError: if a sequence fill does not have one value per image band.
    """
    bands = len(img.getbands())

    if fill is None:
        fill = 0
    if isinstance(fill, (int, float)) and bands > 1:
        # Broadcast a scalar fill to every band.
        fill = (fill,) * bands

    if isinstance(fill, (list, tuple)):
        if len(fill) != bands:
            msg = "The number of elements in 'fill' does not match the number of bands of the image ({} != {})"
            raise ValueError(msg.format(len(fill), bands))
        fill = tuple(fill)

    return {name: fill}
vfdev's avatar
vfdev committed
290
291
292


@torch.jit.unused
def affine(
    img: Image.Image,
    matrix: List[float],
    interpolation: int = Image.NEAREST,
    fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:
    """Apply an affine transform given by its 6-coefficient ``matrix``,
    keeping the output size equal to the input size."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    fill_opts = _parse_fill(fill, img)
    return img.transform(img.size, Image.AFFINE, matrix, interpolation, **fill_opts)
vfdev's avatar
vfdev committed
306
307
308


@torch.jit.unused
def rotate(
    img: Image.Image,
    angle: float,
    interpolation: int = Image.NEAREST,
    expand: bool = False,
    center: Optional[Tuple[int, int]] = None,
    fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:
    """Rotate ``img`` by ``angle`` degrees around ``center`` (image center when
    None); ``expand`` grows the canvas to hold the whole rotated image."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    fill_opts = _parse_fill(fill, img)
    return img.rotate(angle, interpolation, expand, center, **fill_opts)
323
324
325


@torch.jit.unused
def perspective(
    img: Image.Image,
    perspective_coeffs: float,
    interpolation: int = Image.BICUBIC,
    fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:
    """Apply a perspective transform given by its 8 coefficients, keeping the
    output size equal to the input size."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    fill_opts = _parse_fill(fill, img)
    return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **fill_opts)
339
340
341


@torch.jit.unused
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
    """Convert ``img`` to grayscale with 1 channel ("L") or 3 identical
    channels ("RGB" with equal bands).

    Raises:
        TypeError: if ``img`` is not a PIL image.
        ValueError: if ``num_output_channels`` is not 1 or 3.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if num_output_channels == 1:
        return img.convert("L")
    if num_output_channels == 3:
        gray = np.array(img.convert("L"), dtype=np.uint8)
        # Replicate the single luminance plane into three identical channels.
        return Image.fromarray(np.dstack([gray, gray, gray]), "RGB")
    raise ValueError("num_output_channels should be either 1 or 3")
357
358
359


@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
    """Invert the colors of ``img`` (each pixel becomes 255 - pixel)."""
    if _is_pil_image(img):
        return ImageOps.invert(img)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")


@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
    """Reduce each color channel of ``img`` to the given number of bits."""
    if _is_pil_image(img):
        return ImageOps.posterize(img, bits)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")


@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
    """Invert all pixel values of ``img`` that are at or above ``threshold``."""
    if _is_pil_image(img):
        return ImageOps.solarize(img, threshold)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")


@torch.jit.unused
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
    """Scale image sharpness: 0 gives a blurred image, 1 the original image,
    values > 1 a sharpened image."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    return ImageEnhance.Sharpness(img).enhance(sharpness_factor)


@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
    """Maximize image contrast by remapping the darkest pixel to black and the
    lightest to white."""
    if _is_pil_image(img):
        return ImageOps.autocontrast(img)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")


@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
    """Equalize the image histogram to flatten the gray-level distribution."""
    if _is_pil_image(img):
        return ImageOps.equalize(img)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")