import warnings
from typing import Optional, Tuple, List

import torch
from torch import Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad


def _is_tensor_a_torch_image(x: Tensor) -> bool:
    return x.ndim >= 2


def _assert_image_tensor(img: Tensor) -> None:
    if not _is_tensor_a_torch_image(img):
        raise TypeError("Tensor is not a torch image.")


def _assert_threshold(img: Tensor, threshold: float) -> None:
    bound = 1 if img.is_floating_point() else 255
    if threshold > bound:
        raise TypeError("Threshold should be less than bound of img.")


def get_image_size(img: Tensor) -> List[int]:
    # Returns (w, h) of tensor image
    _assert_image_tensor(img)
    return [img.shape[-1], img.shape[-2]]


def get_image_num_channels(img: Tensor) -> int:
    if img.ndim == 2:
        return 1
    elif img.ndim > 2:
        return img.shape[-3]

    raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")


def _max_value(dtype: torch.dtype) -> float:
    # TODO: replace this method with torch.iinfo when it gets torchscript support.
    # https://github.com/pytorch/pytorch/issues/41492

    a = torch.tensor(2, dtype=dtype)
    signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0
    bits = 1
    max_value = torch.tensor(-signed, dtype=torch.long)
    while True:
        next_value = a.pow(bits - signed).sub(1)
        if next_value > max_value:
            max_value = next_value
            bits *= 2
        else:
            break
    return max_value.item()
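
# Illustrative sketch (not part of the upstream file): the loop above reproduces
# torch.iinfo(dtype).max for integer dtypes.
#   >>> _max_value(torch.uint8)   # 255
#   >>> _max_value(torch.int8)    # 127
#   >>> _max_value(torch.int32)   # 2147483647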


def _assert_channels(img: Tensor, permitted: List[int]) -> None:
    c = get_image_num_channels(img)
    if c not in permitted:
        raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    if image.dtype == dtype:
        return image

    if image.is_floating_point():

        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # For data in the range 0-1, (float * 255).to(uint) is only 255
        # when float is exactly 1.0.
        # `max + 1 - epsilon` provides more evenly distributed mapping of
        # ranges of floats to ints.
        eps = 1e-3
        max_val = _max_value(dtype)
        result = image.mul(max_val + 1.0 - eps)
        return result.to(dtype)
    else:
        input_max = _max_value(image.dtype)

        # int to float
        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            image = image.to(dtype)
            return image / input_max

        output_max = _max_value(dtype)

        # int to int
        if input_max > output_max:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image // factor can produce different results
            factor = int((input_max + 1) // (output_max + 1))
            image = torch.div(image, factor, rounding_mode="floor")
            return image.to(dtype)
        else:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image * factor can produce different results
            factor = int((output_max + 1) // (input_max + 1))
            image = image.to(dtype)
            return image * factor
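
# Illustrative sketch (not part of the upstream file): round-tripping dtypes.
#   >>> x = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
#   >>> f = convert_image_dtype(x, torch.float32)  # values scaled into [0.0, 1.0]
#   >>> u = convert_image_dtype(f, torch.uint8)    # values scaled back into [0, 255]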


def vflip(img: Tensor) -> Tensor:
    _assert_image_tensor(img)

    return img.flip(-2)


def hflip(img: Tensor) -> Tensor:
    _assert_image_tensor(img)

    return img.flip(-1)


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    _assert_image_tensor(img)

    w, h = get_image_size(img)
    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)]
        return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
    return img[..., top:bottom, left:right]
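
# Illustrative sketch (not part of the upstream file): regions outside the image
# are zero-padded rather than raising.
#   >>> img = torch.ones(3, 8, 8)
#   >>> out = crop(img, top=-2, left=-2, height=4, width=4)
#   >>> out.shape   # torch.Size([3, 4, 4]); the first two rows and columns are 0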


def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [3])

    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")

    r, g, b = img.unbind(dim=-3)
    # This implementation closely follows the TF one:
    # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
    l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
    l_img = l_img.unsqueeze(dim=-3)

    if num_output_channels == 3:
        return l_img.expand(img.shape)

    return l_img
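
# Illustrative sketch (not part of the upstream file):
#   >>> rgb = torch.rand(3, 4, 4)
#   >>> rgb_to_grayscale(rgb).shape                         # torch.Size([1, 4, 4])
#   >>> rgb_to_grayscale(rgb, num_output_channels=3).shape  # torch.Size([3, 4, 4])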


def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    if brightness_factor < 0:
        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])

    return _blend(img, torch.zeros_like(img), brightness_factor)


def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [3, 1])
    c = get_image_num_channels(img)
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    if c == 3:
        mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
    else:
        mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)

    return _blend(img, mean, contrast_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor image")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])
    if get_image_num_channels(img) == 1:  # Match PIL behaviour
        return img

    orig_dtype = img.dtype
    if img.dtype == torch.uint8:
        img = img.to(dtype=torch.float32) / 255.0

    img = _rgb2hsv(img)
    h, s, v = img.unbind(dim=-3)
    h = (h + hue_factor) % 1.0
    img = torch.stack((h, s, v), dim=-3)
    img_hue_adj = _hsv2rgb(img)

    if orig_dtype == torch.uint8:
        img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)

    return img_hue_adj
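
# Illustrative sketch (not part of the upstream file): hue shifts are cyclic, so
# factors 0.5 and -0.5 land on the same result up to floating-point rounding.
#   >>> img = torch.rand(3, 4, 4)
#   >>> torch.allclose(adjust_hue(img, 0.5), adjust_hue(img, -0.5), atol=1e-5)  # True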


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    if saturation_factor < 0:
        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])

    if get_image_num_channels(img) == 1:  # Match PIL behaviour
        return img

    return _blend(img, rgb_to_grayscale(img), saturation_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    if not isinstance(img, torch.Tensor):
        raise TypeError("Input img should be a Tensor.")

    _assert_channels(img, [1, 3])

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    result = img
    dtype = img.dtype
    if not torch.is_floating_point(img):
        result = convert_image_dtype(result, torch.float32)

    result = (gain * result ** gamma).clamp(0, 1)

    result = convert_image_dtype(result, dtype)
    return result


def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
    """DEPRECATED"""
    warnings.warn(
        "This method is deprecated and will be removed in future releases. Please, use ``F.center_crop`` instead."
    )

    _assert_image_tensor(img)

    _, image_height, image_width = img.size()
    crop_height, crop_width = output_size
    # crop_top = int(round((image_height - crop_height) / 2.))
    # Result can be different between python func and scripted func
    # Temporary workaround:
    crop_top = int((image_height - crop_height + 1) * 0.5)
    # crop_left = int(round((image_width - crop_width) / 2.))
    # Result can be different between python func and scripted func
    # Temporary workaround:
    crop_left = int((image_width - crop_width + 1) * 0.5)

    return crop(img, crop_top, crop_left, crop_height, crop_width)


def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
    """DEPRECATED"""
    warnings.warn(
        "This method is deprecated and will be removed in future releases. Please, use ``F.five_crop`` instead."
    )

    _assert_image_tensor(img)

    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    _, image_height, image_width = img.size()
    crop_height, crop_width = size
    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    tl = crop(img, 0, 0, crop_height, crop_width)
    tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
    br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop(img, (crop_height, crop_width))

    return [tl, tr, bl, br, center]


def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]:
    """DEPRECATED"""
    warnings.warn(
        "This method is deprecated and will be removed in future releases. Please, use ``F.ten_crop`` instead."
    )

    _assert_image_tensor(img)

    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)

    return first_five + second_five


def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    ratio = float(ratio)
    bound = 1.0 if img1.is_floating_point() else 255.0
    return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
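
# Illustrative sketch (not part of the upstream file): _blend computes
# ratio * img1 + (1 - ratio) * img2, clamped to the dtype's value range.
#   >>> a, b = torch.ones(1, 2, 2), torch.zeros(1, 2, 2)
#   >>> _blend(a, b, 0.25)   # every element equals 0.25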


def _rgb2hsv(img: Tensor) -> Tensor:
    r, g, b = img.unbind(dim=-3)

    # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
    # src/libImaging/Convert.c#L330
    maxc = torch.max(img, dim=-3).values
    minc = torch.min(img, dim=-3).values

    # The algorithm erases the S and H channels where `maxc = minc`. This avoids NaNs
    # from appearing in the results, because
    #   + the S channel has a division by `maxc`, which is zero only if `maxc = minc`
    #   + the H channel has a division by `(maxc - minc)`.
    #
    # Instead of overwriting NaNs afterwards, we just prevent them from occurring, so
    # we don't need to deal with them in case we save the NaN in a buffer during
    # backprop, if that is ever supported; either way, it doesn't hurt to do so.
    eqc = maxc == minc

    cr = maxc - minc
    # Since `eqc => cr = 0`, replacing the denominator with 1 when `eqc` is True is fine.
    ones = torch.ones_like(maxc)
    s = cr / torch.where(eqc, ones, maxc)
    # Note that `eqc => maxc = minc = r = g = b`. In that case the following calculation
    # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6`, so it
    # would not matter what values `rc`, `gc`, and `bc` have here, and thus
    # replacing the denominator with 1 when `eqc` is True is fine.
    cr_divisor = torch.where(eqc, ones, cr)
    rc = (maxc - r) / cr_divisor
    gc = (maxc - g) / cr_divisor
    bc = (maxc - b) / cr_divisor

    hr = (maxc == r) * (bc - gc)
    hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
    hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
    h = hr + hg + hb
    h = torch.fmod((h / 6.0 + 1.0), 1.0)
    return torch.stack((h, s, maxc), dim=-3)


def _hsv2rgb(img: Tensor) -> Tensor:
    h, s, v = img.unbind(dim=-3)
    i = torch.floor(h * 6.0)
    f = (h * 6.0) - i
    i = i.to(dtype=torch.int32)

    p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
    q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
    t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
    i = i % 6

    mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)

    a1 = torch.stack((v, q, p, p, t, v), dim=-3)
    a2 = torch.stack((t, v, v, q, p, p), dim=-3)
    a3 = torch.stack((p, p, t, v, v, q), dim=-3)
    a4 = torch.stack((a1, a2, a3), dim=-4)

    return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
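
# Illustrative sketch (not part of the upstream file): for float images in [0, 1]
# the two conversions are inverses up to floating-point error.
#   >>> x = torch.rand(3, 4, 4)
#   >>> torch.allclose(_hsv2rgb(_rgb2hsv(x)), x, atol=1e-5)  # True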


def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
    # padding is left, right, top, bottom

    # crop if needed
    if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
        neg_min_padding = [-min(x, 0) for x in padding]
        crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
        img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
        padding = [max(x, 0) for x in padding]

    in_sizes = img.size()

    _x_indices = [i for i in range(in_sizes[-1])]  # [0, 1, 2, 3, ...]
    left_indices = [i for i in range(padding[0] - 1, -1, -1)]  # e.g. [3, 2, 1, 0]
    right_indices = [-(i + 1) for i in range(padding[1])]  # e.g. [-1, -2, -3]
    x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)

    _y_indices = [i for i in range(in_sizes[-2])]
    top_indices = [i for i in range(padding[2] - 1, -1, -1)]
    bottom_indices = [-(i + 1) for i in range(padding[3])]
    y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)

    ndim = img.ndim
    if ndim == 3:
        return img[:, y_indices[:, None], x_indices[None, :]]
    elif ndim == 4:
        return img[:, :, y_indices[:, None], x_indices[None, :]]
    else:
        raise RuntimeError("Symmetric padding of N-D tensors is not supported yet")
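
# Illustrative sketch (not part of the upstream file): symmetric padding mirrors
# the signal including the edge pixel, so a row [0, 1, 2, 3] padded with
# padding=[2, 2, 0, 0] is read in the index order [1, 0, 0, 1, 2, 3, 3, 2].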


def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
    _assert_image_tensor(img)

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This may be unreachable
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif len(padding) == 1:
        pad_left = pad_right = pad_top = pad_bottom = padding[0]
    elif len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    else:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    p = [pad_left, pad_right, pad_top, pad_bottom]

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)

    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
        # Here we temporary cast input tensor to float
        # until pytorch issue is resolved :
        # https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img
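
# Illustrative sketch (not part of the upstream file): the padding list follows
# the torchvision convention [left, top, right, bottom].
#   >>> img = torch.zeros(3, 4, 4)
#   >>> pad(img, [1, 2]).shape        # torch.Size([3, 8, 6]): 1 left/right, 2 top/bottom
#   >>> pad(img, [1, 2, 3, 4]).shape  # torch.Size([3, 10, 8])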


def resize(
    img: Tensor,
    size: List[int],
    interpolation: str = "bilinear",
    max_size: Optional[int] = None,
    antialias: Optional[bool] = None,
) -> Tensor:
    _assert_image_tensor(img)

    if not isinstance(size, (int, tuple, list)):
        raise TypeError("Got inappropriate size arg")
    if not isinstance(interpolation, str):
        raise TypeError("Got inappropriate interpolation arg")

    if interpolation not in ["nearest", "bilinear", "bicubic"]:
        raise ValueError("This interpolation mode is unsupported with Tensor input")

    if isinstance(size, tuple):
        size = list(size)

    if isinstance(size, list):
        if len(size) not in [1, 2]:
            raise ValueError(
                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
            )
        if max_size is not None and len(size) != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )

    if antialias is None:
        antialias = False

    if antialias and interpolation not in ["bilinear", "bicubic"]:
        raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")

    w, h = get_image_size(img)

    if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
        short, long = (w, h) if w <= h else (h, w)
        requested_new_short = size if isinstance(size, int) else size[0]

        if short == requested_new_short:
            return img

        new_short, new_long = requested_new_short, int(requested_new_short * long / short)

        if max_size is not None:
            if max_size <= requested_new_short:
                raise ValueError(
                    f"max_size = {max_size} must be strictly greater than the requested "
                    f"size for the smaller edge size = {size}"
                )
            if new_long > max_size:
                new_short, new_long = int(max_size * new_short / new_long), max_size

        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)

    else:  # specified both h and w
        new_w, new_h = size[1], size[0]

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])

    # Define align_corners to avoid warnings
    align_corners = False if interpolation in ["bilinear", "bicubic"] else None

    if antialias:
        if interpolation == "bilinear":
            img = torch.ops.torchvision._interpolate_bilinear2d_aa(img, [new_h, new_w], align_corners=False)
        elif interpolation == "bicubic":
            img = torch.ops.torchvision._interpolate_bicubic2d_aa(img, [new_h, new_w], align_corners=False)
    else:
        img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)

    if interpolation == "bicubic" and out_dtype == torch.uint8:
        img = img.clamp(min=0, max=255)

    img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)

    return img
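
# Illustrative sketch (not part of the upstream file):
#   >>> img = torch.rand(3, 200, 100)
#   >>> resize(img, [50]).shape      # torch.Size([3, 100, 50]): shorter edge -> 50
#   >>> resize(img, [64, 32]).shape  # torch.Size([3, 64, 32]): explicit (h, w)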


def _assert_grid_transform_inputs(
    img: Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: Optional[List[float]],
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:

    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor")

    _assert_image_tensor(img)

    if matrix is not None and not isinstance(matrix, list):
        raise TypeError("Argument matrix should be a list")

    if matrix is not None and len(matrix) != 6:
        raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None and not isinstance(fill, (int, float, tuple, list)):
        warnings.warn("Argument fill should be either int, float, tuple or list")

    # Check fill
    num_channels = get_image_num_channels(img)
    if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
        msg = (
            "The number of elements in 'fill' cannot broadcast to match the number of "
            "channels of the image ({} != {})"
        )
        raise ValueError(msg.format(len(fill), num_channels))

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
    need_squeeze = False
    # make image NCHW
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if out_dtype not in req_dtypes:
        need_cast = True
        req_dtype = req_dtypes[0]
        img = img.to(req_dtype)
    return img, need_cast, need_squeeze, out_dtype


def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
            # it is better to round before cast
            img = torch.round(img)
        img = img.to(out_dtype)

    return img


def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor:
    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
        img,
        [
            grid.dtype,
        ],
    )

    if img.shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
        img = torch.cat((img, dummy), dim=1)

    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        mask = img[:, -1:, :, :]  # N * 1 * H * W
        img = img[:, :-1, :, :]  # N * C * H * W
        mask = mask.expand_as(img)
        len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
        fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
        if mode == "nearest":
            mask = mask < 0.5
            img[mask] = fill_img[mask]
        else:  # 'bilinear'
            img = img * mask + (1.0 - mask) * fill_img

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img


def _gen_affine_grid(
    theta: Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
    x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


def affine(
    img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None
) -> Tensor:
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
    shape = img.shape
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
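
# Illustrative sketch (not part of the upstream file): `matrix` holds the six
# coefficients of a 2x3 affine transform in row-major order; the identity
# [1, 0, 0, 0, 1, 0] leaves the image unchanged (up to interpolation).
#   >>> img = torch.rand(3, 8, 8)
#   >>> out = affine(img, matrix=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0])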


def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:

    # Inspired by the PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are the corner points of the image: top-left, bottom-left, bottom-right, top-right.
    pts = torch.tensor(
        [
            [-0.5 * w, -0.5 * h, 1.0],
            [-0.5 * w, 0.5 * h, 1.0],
            [0.5 * w, 0.5 * h, 1.0],
            [0.5 * w, -0.5 * h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
    new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
    min_vals, _ = new_pts.min(dim=0)
    max_vals, _ = new_pts.max(dim=0)

    # Truncate precision to 1e-4 to avoid values like 1e-15 being ceiled up to 1.0
    tol = 1e-4
    cmax = torch.ceil((max_vals / tol).trunc_() * tol)
    cmin = torch.floor((min_vals / tol).trunc_() * tol)
    size = cmax - cmin
    return int(size[0]), int(size[1])


def rotate(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    expand: bool = False,
    fill: Optional[List[float]] = None,
) -> Tensor:
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
    w, h = img.shape[-1], img.shape[-2]
    ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)

    return _apply_grid_transform(img, grid, interpolation, fill=fill)


def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394

    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
    output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))

    output_grid = output_grid1 / output_grid2 - 1.0
    return output_grid.view(1, oh, ow, 2)


def perspective(
    img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None
) -> Tensor:
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor.")

    _assert_image_tensor(img)

    _assert_grid_transform_inputs(
        img,
        matrix=None,
        interpolation=interpolation,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    ow, oh = img.shape[-1], img.shape[-2]
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)


def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel1d = pdf / pdf.sum()

    return kernel1d


def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
    return kernel2d


def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
        img,
        [
            kernel.dtype,
        ],
    )

    # padding = (left, right, top, bottom)
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img
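
# Illustrative sketch (not part of the upstream file):
#   >>> img = torch.rand(3, 32, 32)
#   >>> gaussian_blur(img, kernel_size=[5, 5], sigma=[1.0, 1.0]).shape  # shape unchanged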


def invert(img: Tensor) -> Tensor:

    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
    return bound - img


def posterize(img: Tensor, bits: int) -> Tensor:

    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])
    mask = -int(2 ** (8 - bits))  # JIT-friendly for: ~(2 ** (8 - bits) - 1)
    return img & mask
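
# Illustrative sketch (not part of the upstream file): the mask keeps the top
# `bits` bits of every channel value.
#   >>> x = torch.tensor([[[255, 130]]], dtype=torch.uint8)
#   >>> posterize(x, bits=2)   # tensor([[[192, 128]]], dtype=torch.uint8)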


def solarize(img: Tensor, threshold: float) -> Tensor:

    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    _assert_threshold(img, threshold)

    inverted_img = invert(img)
    return torch.where(img >= threshold, inverted_img, img)


def _blurred_degenerate_image(img: Tensor) -> Tensor:
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
    kernel[1, 1] = 5.0
    kernel /= kernel.sum()
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
        img,
        [
            kernel.dtype,
        ],
    )
    result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
    result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)

    result = img.clone()
    result[..., 1:-1, 1:-1] = result_tmp

    return result


def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    if sharpness_factor < 0:
        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])

    if img.size(-1) <= 2 or img.size(-2) <= 2:
        return img

    return _blend(img, _blurred_degenerate_image(img), sharpness_factor)


def autocontrast(img: Tensor) -> Tensor:

    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    bound = 1.0 if img.is_floating_point() else 255.0
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
    maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
    eq_idxs = torch.where(minimum == maximum)[0]
    minimum[eq_idxs] = 0
    maximum[eq_idxs] = bound
    scale = bound / (maximum - minimum)

    return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)


def _scale_channel(img_chan: Tensor) -> Tensor:
    # TODO: we should expect bincount to always be faster than histc, but this
    # isn't always the case. Once
    # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
    # block and only use bincount.
    if img_chan.is_cuda:
        hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
    else:
        hist = torch.bincount(img_chan.view(-1), minlength=256)

    nonzero_hist = hist[hist != 0]
    step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
    if step == 0:
        return img_chan

    lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
    lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)

    return lut[img_chan.to(torch.int64)].to(torch.uint8)


def _equalize_single_image(img: Tensor) -> Tensor:
    return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])


def equalize(img: Tensor) -> Tensor:

    _assert_image_tensor(img)

    if not (3 <= img.ndim <= 4):
        raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])

    if img.ndim == 3:
        return _equalize_single_image(img)

    return torch.stack([_equalize_single_image(x) for x in img])
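
# Illustrative sketch (not part of the upstream file):
#   >>> img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
#   >>> equalize(img).shape   # torch.Size([3, 32, 32]), histogram spread over [0, 255]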