import numbers
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from PIL import Image, ImageEnhance, ImageOps

try:
    import accimage
except ImportError:
    accimage = None


@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)
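# Note: accimage is an optional accelerated image backend; when it is not
# installed, only plain PIL.Image instances pass this check.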


@torch.jit.unused
def get_dimensions(img: Any) -> List[int]:
    if _is_pil_image(img):
        if hasattr(img, "getbands"):
            channels = len(img.getbands())
        else:
            channels = img.channels
        width, height = img.size
        return [channels, height, width]
    raise TypeError(f"Unexpected type {type(img)}")
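# Usage sketch (illustrative values): the result uses the tensor-style
# [channels, height, width] order, while PIL's img.size is (width, height):
#   >>> get_dimensions(Image.new("RGB", (640, 480)))
#   [3, 480, 640]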


@torch.jit.unused
def get_image_size(img: Any) -> List[int]:
    if _is_pil_image(img):
        return list(img.size)
    raise TypeError(f"Unexpected type {type(img)}")


@torch.jit.unused
def get_image_num_channels(img: Any) -> int:
    if _is_pil_image(img):
        if hasattr(img, "getbands"):
            return len(img.getbands())
        else:
            return img.channels
    raise TypeError(f"Unexpected type {type(img)}")


@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    return img.transpose(Image.FLIP_LEFT_RIGHT)


@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    return img.transpose(Image.FLIP_TOP_BOTTOM)


@torch.jit.unused
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img
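# Factor sketch: ImageEnhance factors here and in the contrast, saturation and
# sharpness helpers are multiplicative; 1.0 is a no-op, 0.0 removes the
# property entirely (black, solid gray, grayscale or blurred, respectively),
# and values above 1.0 amplify it.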


@torch.jit.unused
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


@torch.jit.unused
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


@torch.jit.unused
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    input_mode = img.mode
    if input_mode in {"L", "1", "I", "F"}:
        return img

    h, s, v = img.convert("HSV").split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over="ignore"):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, "L")

    img = Image.merge("HSV", (h, s, v)).convert(input_mode)
    return img
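# Wraparound sketch (illustrative values): hue is circular, so the deliberate
# uint8 overflow above rotates it, e.g. np.uint8(240) + np.uint8(32) wraps to
# 16 rather than saturating at 255.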


@torch.jit.unused
def adjust_gamma(
    img: Image.Image,
    gamma: float,
    gain: float = 1.0,
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    input_mode = img.mode
    img = img.convert("RGB")
    gamma_map = [int((255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma)) for ele in range(256)] * 3
    img = img.point(gamma_map)  # use PIL's point-function to accelerate this part

    img = img.convert(input_mode)
    return img
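# LUT sketch (illustrative values): with gamma=2.0 and gain=1.0 the table maps
# input 128 to int((256 - 1e-3) * (128 / 255) ** 2) == 64, darkening midtones
# while leaving 0 and 255 (almost) fixed.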


@torch.jit.unused
def pad(
    img: Image.Image,
    padding: Union[int, List[int], Tuple[int, ...]],
    fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
    padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if not isinstance(padding, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, list):
        padding = tuple(padding)

    if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")

    if isinstance(padding, tuple) and len(padding) == 1:
        # Compatibility with `functional_tensor.pad`
        padding = padding[0]

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    if padding_mode == "constant":
        opts = _parse_fill(fill, img, name="fill")
        if img.mode == "P":
            palette = img.getpalette()
            image = ImageOps.expand(img, border=padding, **opts)
            image.putpalette(palette)
            return image

        return ImageOps.expand(img, border=padding, **opts)
    else:
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        if isinstance(padding, tuple) and len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        if isinstance(padding, tuple) and len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]

        p = [pad_left, pad_top, pad_right, pad_bottom]
        cropping = -np.minimum(p, 0)

        if cropping.any():
            crop_left, crop_top, crop_right, crop_bottom = cropping
            img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom))

        pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)

        if img.mode == "P":
            palette = img.getpalette()
            img = np.asarray(img)
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
            img = Image.fromarray(img)
            img.putpalette(palette)
            return img

        img = np.asarray(img)
        # RGB image
        if len(img.shape) == 3:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
        # Grayscale image
        if len(img.shape) == 2:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

        return Image.fromarray(img)
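# Behaviour sketch (illustrative values): negative entries in `padding` crop
# instead of pad in the non-constant modes, e.g. pad(img, (-10, 0),
# padding_mode="edge") removes 10 columns from the left and right edges.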


@torch.jit.unused
def crop(
    img: Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    return img.crop((left, top, left + width, top + height))
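# Argument-order sketch: crop takes (top, left, height, width) and translates
# it to PIL's (left, upper, right, lower) box, so crop(img, 10, 20, 100, 200)
# is equivalent to img.crop((20, 10, 220, 110)).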


@torch.jit.unused
def resize(
    img: Image.Image,
    size: Union[List[int], int],
    interpolation: int = Image.BILINEAR,
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if not (isinstance(size, list) and len(size) == 2):
        raise TypeError(f"Got inappropriate size arg: {size}")

    return img.resize(tuple(size[::-1]), interpolation)
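# Order sketch: `size` is [height, width] (the tensor convention), while PIL's
# Image.resize expects (width, height); hence the tuple(size[::-1]) flip, e.g.
# resize(img, [480, 640]) calls img.resize((640, 480), interpolation).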


@torch.jit.unused
def _parse_fill(
    fill: Optional[Union[float, List[float], Tuple[float, ...]]],
    img: Image.Image,
    name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
    # Process fill color for affine transforms
    num_channels = get_image_num_channels(img)
    if fill is None:
        fill = 0
    if isinstance(fill, (int, float)) and num_channels > 1:
        fill = tuple([fill] * num_channels)
    if isinstance(fill, (list, tuple)):
        if len(fill) != num_channels:
            msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
            raise ValueError(msg.format(len(fill), num_channels))

        fill = tuple(fill)

    if img.mode != "F":
        if isinstance(fill, (list, tuple)):
            fill = tuple(int(x) for x in fill)
        else:
            fill = int(fill)

    return {name: fill}
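# Fill-expansion sketch (illustrative values): for a 3-channel RGB image,
# _parse_fill(1, img) returns {"fillcolor": (1, 1, 1)}; for mode "F" images
# the components are left as floats instead of being cast to int.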


@torch.jit.unused
def affine(
    img: Image.Image,
    matrix: List[float],
    interpolation: int = Image.NEAREST,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    output_size = img.size
    opts = _parse_fill(fill, img)
    return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts)


@torch.jit.unused
def rotate(
    img: Image.Image,
    angle: float,
    interpolation: int = Image.NEAREST,
    expand: bool = False,
    center: Optional[Tuple[int, int]] = None,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    opts = _parse_fill(fill, img)
    return img.rotate(angle, interpolation, expand, center, **opts)


@torch.jit.unused
def perspective(
    img: Image.Image,
    perspective_coeffs: List[float],
    interpolation: int = Image.BICUBIC,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    opts = _parse_fill(fill, img)

    return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts)


@torch.jit.unused
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if num_output_channels == 1:
        img = img.convert("L")
    elif num_output_channels == 3:
        img = img.convert("L")
        np_img = np.array(img, dtype=np.uint8)
        np_img = np.dstack([np_img, np_img, np_img])
        img = Image.fromarray(np_img, "RGB")
    else:
        raise ValueError("num_output_channels should be either 1 or 3")

    return img
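# Channel sketch: with num_output_channels=3 the result still looks grayscale;
# the single luma channel is replicated into R, G and B via np.dstack above.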


@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.invert(img)


@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.posterize(img, bits)


@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.solarize(img, threshold)


@torch.jit.unused
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    enhancer = ImageEnhance.Sharpness(img)
    img = enhancer.enhance(sharpness_factor)
    return img


@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.autocontrast(img)


@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.equalize(img)