functional_tensor.py 33.5 KB
Newer Older
vfdev's avatar
vfdev committed
1
import warnings
2
from typing import Optional, Tuple, List
vfdev's avatar
vfdev committed
3

4
import torch
5
from torch import Tensor
6
from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad
7
8


vfdev's avatar
vfdev committed
9
10
def _is_tensor_a_torch_image(x: Tensor) -> bool:
    """Return True when ``x`` has at least two dimensions (the trailing height/width dims)."""
    return x.ndim > 1
11
12


13
def _assert_image_tensor(img: Tensor) -> None:
    """Raise ``TypeError`` unless ``img`` passes ``_is_tensor_a_torch_image`` (ndim >= 2)."""
    is_image = _is_tensor_a_torch_image(img)
    if is_image:
        return
    raise TypeError("Tensor is not a torch image.")


puhuk's avatar
puhuk committed
18
19
20
21
22
23
def _assert_threshold(img: Tensor, threshold: float) -> None:
    """Raise ``TypeError`` if ``threshold`` exceeds the largest value of ``img``'s range.

    The bound comes from ``_max_value``: 1 for floating-point images and the
    dtype maximum for integer images (e.g. 255 for uint8, 32767 for int16).
    A threshold equal to the bound is accepted.

    Raises:
        TypeError: if ``threshold`` is greater than the bound.
    """
    # A hard-coded 255 rejected valid thresholds for int16/int32/int64 images;
    # use the same dtype-aware maximum that convert_image_dtype relies on.
    bound = _max_value(img.dtype)
    if threshold > bound:
        raise TypeError("Threshold should be less than bound of img.")


24
25
26
27
28
29
30
def get_dimensions(img: Tensor) -> List[int]:
    """Return ``[channels, height, width]`` of an image tensor.

    A 2-D tensor counts as a single-channel image; otherwise the channel count
    is read from dimension -3. Height and width are the last two dimensions.
    """
    _assert_image_tensor(img)
    height = img.shape[-2]
    width = img.shape[-1]
    if img.ndim == 2:
        channels = 1
    else:
        channels = img.shape[-3]
    return [channels, height, width]


31
def get_image_size(img: Tensor) -> List[int]:
    """Return ``[width, height]`` of an image tensor (note: width comes first)."""
    _assert_image_tensor(img)
    width = img.shape[-1]
    height = img.shape[-2]
    return [width, height]
vfdev's avatar
vfdev committed
35
36


37
def get_image_num_channels(img: Tensor) -> int:
    """Return the channel count of an image tensor.

    2-D tensors are single-channel; for higher ranks dimension -3 holds the channels.

    Raises:
        TypeError: if ``img`` is not an image tensor (fewer than 2 dims).
    """
    _assert_image_tensor(img)
    ndim = img.ndim
    if ndim > 2:
        return img.shape[-3]
    if ndim == 2:
        return 1
    raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
45
46


47
48
def _max_value(dtype: torch.dtype) -> int:
    """Return the largest representable value of an integer ``dtype``.

    Handles uint8, int8, int16, int32 and int64. Every other dtype, including
    the floating-point ones, maps to 1.
    """
    if dtype == torch.int64:
        return 9223372036854775807
    if dtype == torch.int32:
        return 2147483647
    if dtype == torch.int16:
        return 32767
    if dtype == torch.int8:
        return 127
    if dtype == torch.uint8:
        return 255
    return 1
60
61


62
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
    """Raise ``TypeError`` if the channel count of ``img`` is not one of ``permitted``."""
    num_channels = get_dimensions(img)[0]
    if num_channels in permitted:
        return
    raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {num_channels}")
66
67


68
69
70
71
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert ``image`` to ``dtype``, rescaling values between dtype ranges.

    Integer values are scaled by their dtype maximum (``_max_value``) when
    going to floating point, and floating values are scaled up by the target
    maximum when going to an integer dtype. Float->float is a plain cast, and
    int->int rescales by an integer factor. An image already in ``dtype`` is
    returned unchanged.

    Raises:
        RuntimeError: for float32 -> int32/int64 and float64 -> int64, which
            cannot be performed safely.
    """
    if image.dtype == dtype:
        return image

    # TODO: replace with dtype.is_floating_point when torchscript supports it
    target_is_float = torch.tensor(0, dtype=dtype).is_floating_point()

    if not image.is_floating_point():
        input_max = float(_max_value(image.dtype))

        # int -> float
        if target_is_float:
            return image.to(dtype) / input_max

        output_max = float(_max_value(dtype))

        # int -> int. The factor must be an int for torch jit script;
        # a float factor can produce different results.
        if input_max > output_max:
            factor = int((input_max + 1) // (output_max + 1))
            return torch.div(image, factor, rounding_mode="floor").to(dtype)
        factor = int((output_max + 1) // (input_max + 1))
        return image.to(dtype) * factor

    # float -> float
    if target_is_float:
        return image.to(dtype)

    # float -> int
    unsafe_cast = (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
        image.dtype == torch.float64 and dtype == torch.int64
    )
    if unsafe_cast:
        raise RuntimeError(f"The cast from {image.dtype} to {dtype} cannot be performed safely.")

    # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
    # For data in the range 0-1, (float * 255).to(uint) is only 255
    # when float is exactly 1.0.
    # `max + 1 - epsilon` provides more evenly distributed mapping of
    # ranges of floats to ints.
    eps = 1e-3
    max_val = float(_max_value(dtype))
    return image.mul(max_val + 1.0 - eps).to(dtype)


vfdev's avatar
vfdev committed
120
def vflip(img: Tensor) -> Tensor:
    """Flip ``img`` vertically by reversing dimension -2."""
    _assert_image_tensor(img)
    return torch.flip(img, [-2])
124
125


vfdev's avatar
vfdev committed
126
def hflip(img: Tensor) -> Tensor:
    """Flip ``img`` horizontally by reversing dimension -1."""
    _assert_image_tensor(img)
    return torch.flip(img, [-1])
ekka's avatar
ekka committed
130
131


vfdev's avatar
vfdev committed
132
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop rows [top, top + height) and columns [left, left + width) from ``img``.

    Any part of the requested box outside the image is filled with zeros via ``pad``.
    """
    _assert_image_tensor(img)

    _, img_h, img_w = get_dimensions(img)
    bottom = top + height
    right = left + width

    fully_inside = left >= 0 and top >= 0 and right <= img_w and bottom <= img_h
    if fully_inside:
        return img[..., top:bottom, left:right]

    # Amount of the box sticking out on each side, in (left, top, right, bottom) order
    padding_ltrb = [max(-left, 0), max(-top, 0), max(right - img_w, 0), max(bottom - img_h, 0)]
    visible = img[..., max(top, 0) : bottom, max(left, 0) : right]
    return pad(visible, padding_ltrb, fill=0)
143
144


145
146
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert a 3-channel RGB image (channels on dim -3) to grayscale.

    Args:
        img: tensor with at least 3 dims and exactly 3 channels.
        num_output_channels: 1, or 3 to get an ``expand``-ed view that repeats
            the gray channel to the input shape.

    Returns:
        Grayscale tensor in ``img.dtype``.

    Raises:
        TypeError: if ``img`` has fewer than 3 dims or not 3 channels.
        ValueError: if ``num_output_channels`` is not 1 or 3.
    """
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [3])

    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")

    # This implementation closely follows the TF one:
    # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
    red, green, blue = img.unbind(dim=-3)
    gray = (0.2989 * red + 0.587 * green + 0.114 * blue).to(img.dtype).unsqueeze(dim=-3)

    return gray.expand(img.shape) if num_output_channels == 3 else gray
163
164


vfdev's avatar
vfdev committed
165
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Scale brightness by blending ``img`` with an all-zero (black) image.

    A factor of 0 yields black; 1 keeps the image; larger values brighten
    (the result is clamped by ``_blend``).

    Raises:
        ValueError: if ``brightness_factor`` is negative.
    """
    if brightness_factor < 0:
        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    black = torch.zeros_like(img)
    return _blend(img, black, brightness_factor)
174
175


vfdev's avatar
vfdev committed
176
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust contrast by blending ``img`` with its mean gray level.

    The mean is taken over the last three dims (of the grayscale version for
    RGB input). It is computed in ``img.dtype`` for floating images and in
    float32 otherwise.

    Raises:
        ValueError: if ``contrast_factor`` is negative.
    """
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [3, 1])

    mean_dtype = torch.float32
    if torch.is_floating_point(img):
        mean_dtype = img.dtype

    is_rgb = get_dimensions(img)[0] == 3
    luminance = rgb_to_grayscale(img).to(mean_dtype) if is_rgb else img.to(mean_dtype)
    mean = torch.mean(luminance, dim=(-3, -2, -1), keepdim=True)

    return _blend(img, mean, contrast_factor)


193
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Rotate the hue of an RGB image by ``hue_factor``.

    The image goes through HSV (``_rgb2hsv``/``_hsv2rgb``), where the hue
    channel is shifted and wrapped modulo 1. uint8 images are processed as
    float32 in [0, 1] and converted back. Single-channel images are returned
    unchanged.

    Raises:
        ValueError: if ``hue_factor`` is not in [-0.5, 0.5].
        TypeError: if ``img`` is not a tensor image with 1 or 3 channels.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor image")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    # Match PIL behaviour: a single-channel image has no hue to shift
    if get_dimensions(img)[0] == 1:
        return img

    orig_dtype = img.dtype
    was_uint8 = orig_dtype == torch.uint8
    if was_uint8:
        img = img.to(dtype=torch.float32) / 255.0

    hue, sat, val = _rgb2hsv(img).unbind(dim=-3)
    hue = (hue + hue_factor) % 1.0
    rgb = _hsv2rgb(torch.stack((hue, sat, val), dim=-3))

    if was_uint8:
        rgb = (rgb * 255.0).to(dtype=orig_dtype)
    return rgb


vfdev's avatar
vfdev committed
222
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Blend ``img`` with its grayscale version by ``saturation_factor``.

    0 gives the grayscale image and 1 keeps the original. Single-channel
    images are returned unchanged.

    Raises:
        ValueError: if ``saturation_factor`` is negative.
    """
    if saturation_factor < 0:
        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    # Match PIL behaviour: nothing to desaturate in a single-channel image
    num_channels = get_dimensions(img)[0]
    if num_channels == 1:
        return img

    gray = rgb_to_grayscale(img)
    return _blend(img, gray, saturation_factor)
234
235


236
237
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    """Apply ``gain * x ** gamma`` to the image, clamped to [0, 1] before converting back.

    Non-floating images are first converted to float32 with
    ``convert_image_dtype`` and converted back to their original dtype afterwards.

    Raises:
        TypeError: if ``img`` is not a Tensor or does not have 1 or 3 channels.
        ValueError: if ``gamma`` is negative.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError("Input img should be a Tensor.")

    _assert_channels(img, [1, 3])

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    orig_dtype = img.dtype
    work = img if torch.is_floating_point(img) else convert_image_dtype(img, torch.float32)
    work = (gain * work ** gamma).clamp(0, 1)
    return convert_image_dtype(work, orig_dtype)


vfdev's avatar
vfdev committed
256
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    """Return ``ratio * img1 + (1 - ratio) * img2`` clamped to ``img1``'s value range.

    The lower bound is 0 and the upper bound is ``_max_value(img1.dtype)``
    (1 for floating-point images, the dtype maximum for integer images).
    The result is cast back to ``img1.dtype``.
    """
    ratio = float(ratio)
    # A fixed 255 upper bound wrongly saturates int16/int32/int64 images and lets
    # int8 values overflow on the cast back; use the dtype's own maximum instead.
    bound = _max_value(img1.dtype)
    return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
260
261


262
def _rgb2hsv(img: Tensor) -> Tensor:
    """Convert an RGB image (channels on dim -3) to HSV, stacked on dim -3.

    Value is the per-pixel channel maximum, saturation is chroma / value, and
    hue is wrapped into [0, 1) with ``fmod``.
    """
    red, green, blue = img.unbind(dim=-3)

    # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
    # src/libImaging/Convert.c#L330
    max_c = torch.max(img, dim=-3).values
    min_c = torch.min(img, dim=-3).values

    # Where max_c == min_c the S and H channels are forced to zero, avoiding NaN:
    #   + S divides by max_c, which is zero only if max_c == min_c
    #   + H divides by (max_c - min_c)
    # We prevent NaN from occurring rather than overwriting it afterwards, so no NaN
    # ever lands in a buffer saved for backprop (should that ever be supported).
    flat = max_c == min_c
    chroma = max_c - min_c
    ones = torch.ones_like(max_c)

    # `flat` implies chroma == 0, so a denominator of 1 there leaves s == 0.
    sat = chroma / torch.where(flat, ones, max_c)

    # `flat` implies max_c == r == g == b, so every numerator below is 0 and the
    # denominator value does not matter; using 1 there avoids 0 / 0.
    chroma_safe = torch.where(flat, ones, chroma)
    red_c = (max_c - red) / chroma_safe
    green_c = (max_c - green) / chroma_safe
    blue_c = (max_c - blue) / chroma_safe

    hue_r = (max_c == red) * (blue_c - green_c)
    hue_g = ((max_c == green) & (max_c != red)) * (2.0 + red_c - blue_c)
    hue_b = ((max_c != green) & (max_c != red)) * (4.0 + green_c - red_c)
    hue = hue_r + hue_g + hue_b
    hue = torch.fmod((hue / 6.0 + 1.0), 1.0)
    return torch.stack((hue, sat, max_c), dim=-3)
299
300


301
def _hsv2rgb(img: Tensor) -> Tensor:
    """Convert an HSV image (channels on dim -3) back to RGB, stacked on dim -3."""
    hue, sat, val = img.unbind(dim=-3)

    # Split hue into one of six sectors plus the fractional position inside it
    sector = torch.floor(hue * 6.0)
    frac = (hue * 6.0) - sector
    sector = sector.to(dtype=torch.int32)

    p = torch.clamp((val * (1.0 - sat)), 0.0, 1.0)
    q = torch.clamp((val * (1.0 - sat * frac)), 0.0, 1.0)
    t = torch.clamp((val * (1.0 - sat * (1.0 - frac))), 0.0, 1.0)
    sector = sector % 6

    # One-hot sector selector, inserted as a size-6 dim at -3
    selector = sector.unsqueeze(dim=-3) == torch.arange(6, device=sector.device).view(-1, 1, 1)

    # Candidate r, g, b values for each of the six sectors
    red_opts = torch.stack((val, q, p, p, t, val), dim=-3)
    green_opts = torch.stack((t, val, val, q, p, p), dim=-3)
    blue_opts = torch.stack((p, p, t, val, val, q), dim=-3)
    candidates = torch.stack((red_opts, green_opts, blue_opts), dim=-4)

    # Pick, per pixel, the candidate of the active sector for each output channel
    return torch.einsum("...ijk, ...xijk -> ...xjk", selector.to(dtype=img.dtype), candidates)
320
321


322
323
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
    """Pad the last two dims of ``img`` by mirroring, repeating the edge pixel.

    Args:
        img: 3-D or 4-D tensor.
        padding: ``[left, right, top, bottom]``; a negative entry crops that side.

    Raises:
        RuntimeError: if ``img`` is not 3-D or 4-D.
    """
    # Negative padding means cropping: crop first, then pad with what remains.
    if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
        crop_left, crop_right, crop_top, crop_bottom = [max(-p, 0) for p in padding]
        img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
        padding = [max(p, 0) for p in padding]

    height = img.shape[-2]
    width = img.shape[-1]

    # Index lists: mirrored leading part (e.g. [3, 2, 1, 0]), the original range,
    # then a mirrored trailing part via negative indices (e.g. [-1, -2, -3]).
    cols = (
        [k for k in range(padding[0] - 1, -1, -1)]
        + [k for k in range(width)]
        + [-(k + 1) for k in range(padding[1])]
    )
    rows = (
        [k for k in range(padding[2] - 1, -1, -1)]
        + [k for k in range(height)]
        + [-(k + 1) for k in range(padding[3])]
    )
    col_idx = torch.tensor(cols, device=img.device)[None, :]
    row_idx = torch.tensor(rows, device=img.device)[:, None]

    if img.ndim == 3:
        return img[:, row_idx, col_idx]
    if img.ndim == 4:
        return img[:, :, row_idx, col_idx]
    raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")


353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
def _parse_pad_padding(padding: List[int]) -> List[int]:
    """Normalize a padding spec to ``[left, right, top, bottom]``.

    Accepts an int or a 1-element list (same on all sides), a 2-element list
    ``[left/right, top/bottom]``, or a 4-element list ``[left, top, right, bottom]``.
    """
    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This maybe unreachable
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        return [padding, padding, padding, padding]

    if len(padding) == 1:
        every_side = padding[0]
        return [every_side, every_side, every_side, every_side]

    if len(padding) == 2:
        horizontal = padding[0]
        vertical = padding[1]
        return [horizontal, horizontal, vertical, vertical]

    # [left, top, right, bottom] -> [left, right, top, bottom]
    return [padding[0], padding[2], padding[1], padding[3]]


373
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
    """Pad the last two dimensions of ``img``.

    Args:
        img: image tensor (at least 2 dims).
        padding: int or 1/2/4-element sequence, interpreted by ``_parse_pad_padding``.
        fill: value used for ``padding_mode == "constant"``.
        padding_mode: "constant", "edge", "reflect" or "symmetric".

    Returns:
        The padded tensor in the input dtype.

    Raises:
        TypeError: for arguments of the wrong type.
        ValueError: for a bad padding length or an unknown padding mode.
    """
    _assert_image_tensor(img)

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    pad_lrtb = _parse_pad_padding(padding)

    if padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, pad_lrtb)

    # torch calls "edge" padding "replicate"
    torch_mode = "replicate" if padding_mode == "edge" else padding_mode

    # Add a leading dim to inputs with fewer than 4 dims; it is removed afterwards.
    squeeze_back = img.ndim < 4
    if squeeze_back:
        img = img.unsqueeze(dim=0)

    # Temporarily cast to float for non-constant modes until pytorch issue is resolved:
    # https://github.com/pytorch/pytorch/issues/40763
    out_dtype = img.dtype
    cast_back = padding_mode != "constant" and img.dtype not in (torch.float32, torch.float64)
    if cast_back:
        img = img.to(torch.float32)

    img = torch_pad(img, pad_lrtb, mode=torch_mode, value=float(fill))

    if squeeze_back:
        img = img.squeeze(dim=0)
    if cast_back:
        img = img.to(out_dtype)
    return img
vfdev's avatar
vfdev committed
424
425


426
427
428
429
430
def resize(
    img: Tensor,
    size: List[int],
    interpolation: str = "bilinear",
    max_size: Optional[int] = None,
    antialias: Optional[bool] = None,
) -> Tensor:
    """Resize the last two dimensions of ``img``.

    Args:
        img: image tensor (at least 2 dims).
        size: an int or 1-element list is the target length of the shorter
            edge (aspect ratio kept); a 2-element list is ``[height, width]``.
        interpolation: "nearest", "bilinear" or "bicubic".
        max_size: optional cap on the longer edge. It is only allowed with a
            single-value ``size`` and must be greater than it.
        antialias: forwarded to ``interpolate``; ``None`` means ``False``.
            It may only be enabled for "bilinear" and "bicubic".

    Returns:
        The resized tensor in the input dtype. In shorter-edge mode, ``img``
        itself is returned when the computed size equals the current one.

    Raises:
        TypeError: for arguments of the wrong type.
        ValueError: for unsupported modes or inconsistent size/max_size/antialias.
    """
    _assert_image_tensor(img)

    if not isinstance(size, (int, tuple, list)):
        raise TypeError("Got inappropriate size arg")
    if not isinstance(interpolation, str):
        raise TypeError("Got inappropriate interpolation arg")

    if interpolation not in ["nearest", "bilinear", "bicubic"]:
        raise ValueError("This interpolation mode is unsupported with Tensor input")

    if isinstance(size, tuple):
        size = list(size)

    if isinstance(size, list):
        if len(size) not in [1, 2]:
            raise ValueError(
                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
            )
        if max_size is not None and len(size) != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )

    if antialias is None:
        antialias = False

    if antialias and interpolation not in ["bilinear", "bicubic"]:
        raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")

    _, h, w = get_dimensions(img)

    if isinstance(size, int) or len(size) == 1:  # only the shorter edge is specified
        short_edge, long_edge = (w, h) if w <= h else (h, w)
        target_short = size if isinstance(size, int) else size[0]
        target_long = int(target_short * long_edge / short_edge)

        if max_size is not None:
            if max_size <= target_short:
                raise ValueError(
                    f"max_size = {max_size} must be strictly greater than the requested "
                    f"size for the smaller edge size = {size}"
                )
            if target_long > max_size:
                target_short, target_long = int(max_size * target_short / target_long), max_size

        if w <= h:
            new_w, new_h = target_short, target_long
        else:
            new_w, new_h = target_long, target_short

        if (w, h) == (new_w, new_h):
            return img
    else:  # both height and width are specified
        new_h, new_w = size[0], size[1]

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])

    # Define align_corners to avoid warnings
    align_corners = False if interpolation in ["bilinear", "bicubic"] else None

    img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners, antialias=antialias)

    # Bicubic can overshoot the valid range; clamp before casting back to uint8
    if interpolation == "bicubic" and out_dtype == torch.uint8:
        img = img.clamp(min=0, max=255)

    return _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
vfdev's avatar
vfdev committed
501
502


vfdev's avatar
vfdev committed
503
def _assert_grid_transform_inputs(
    img: Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: Optional[List[float]],
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    """Validate the shared arguments of the grid-sampling transforms.

    Checks, in order: ``img`` is an image tensor; ``matrix`` (if given) is a
    6-element list; ``coeffs`` (if given) has 8 values; ``fill`` has a sensible
    type (warning only) and either one value or one per channel;
    ``interpolation`` is one of ``supported_interpolation_modes``.

    Raises:
        TypeError: for a non-tensor ``img`` or a non-list ``matrix``.
        ValueError: for wrong ``matrix``/``coeffs``/``fill`` lengths or an
            unsupported interpolation mode.
    """
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor")

    _assert_image_tensor(img)

    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        if len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None and not isinstance(fill, (int, float, tuple, list)):
        warnings.warn("Argument fill should be either int, float, tuple or list")

    # fill is either a single value or one value per channel
    num_channels = get_dimensions(img)[0]
    if isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
        msg = (
            "The number of elements in 'fill' cannot broadcast to match the number of "
            "channels of the image ({} != {})"
        )
        raise ValueError(msg.format(len(fill), num_channels))

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
vfdev's avatar
vfdev committed
540
541


vfdev's avatar
vfdev committed
542
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
    """Prepare ``img`` for ops that expect an NCHW tensor of one of ``req_dtypes``.

    Adds a leading dim when ``img`` has fewer than 4 dims, and casts to
    ``req_dtypes[0]`` when the current dtype is not in ``req_dtypes``.

    Returns:
        ``(img, need_cast, need_squeeze, out_dtype)``; pass the flags and the
        original dtype to ``_cast_squeeze_out`` to undo the changes.
    """
    out_dtype = img.dtype

    need_squeeze = img.ndim < 4
    if need_squeeze:
        img = img.unsqueeze(dim=0)

    need_cast = out_dtype not in req_dtypes
    if need_cast:
        img = img.to(req_dtypes[0])

    return img, need_cast, need_squeeze, out_dtype
vfdev's avatar
vfdev committed
556
557


558
def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
    """Undo ``_cast_squeeze_in``: drop the added leading dim and restore ``out_dtype``.

    Values are rounded before casting back to an integer dtype.
    """
    if need_squeeze:
        img = img.squeeze(dim=0)

    if not need_cast:
        return img

    if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
        # it is better to round before cast
        img = torch.round(img)
    return img.to(out_dtype)
vfdev's avatar
vfdev committed
569
570


571
def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor:
    """Resample ``img`` at the locations in ``grid`` using ``grid_sample``.

    Samples falling outside the image are zero (``padding_mode="zeros"``), or
    take the ``fill`` value(s) when ``fill`` is given. The result is returned
    in the input's shape rank and dtype.
    """
    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])

    batch = img.shape[0]
    if batch > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(batch, grid.shape[1], grid.shape[2], grid.shape[3])

    # Append an all-ones channel: after sampling it records how much of each output
    # pixel came from inside the image. Should be faster than grid_sample() twice.
    if fill is not None:
        ones = torch.ones((batch, 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
        img = torch.cat((img, ones), dim=1)

    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    if fill is not None:
        coverage = img[:, -1:, :, :]  # N * 1 * H * W
        img = img[:, :-1, :, :]  # N * C * H * W
        coverage = coverage.expand_as(img)
        num_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
        fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, num_fill, 1, 1).expand_as(img)
        if mode == "nearest":
            outside = coverage < 0.5
            img[outside] = fill_img[outside]
        else:  # 'bilinear'
            img = img * coverage + (1.0 - coverage) * fill_img

    return _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)


608
def _gen_affine_grid(
    theta: Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> Tensor:
    """Build a ``(1, oh, ow, 2)`` sampling grid for ``grid_sample`` from affine ``theta``.

    Output pixel centers (relative to the output center) are mapped through
    ``theta`` and then divided by half the *input* size ``(w, h)``. The callers
    in this file pass ``theta`` as a (1, 2, 3) tensor.
    """
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate

    half_pixel = 0.5
    coords = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
    xs = torch.linspace(-ow * 0.5 + half_pixel, ow * 0.5 + half_pixel - 1, steps=ow, device=theta.device)
    ys = torch.linspace(-oh * 0.5 + half_pixel, oh * 0.5 + half_pixel - 1, steps=oh, device=theta.device)
    coords[..., 0].copy_(xs)
    coords[..., 1].copy_(ys.unsqueeze(-1))
    coords[..., 2].fill_(1)

    half_input_size = torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
    rescaled_theta = theta.transpose(1, 2) / half_input_size
    output_grid = coords.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


vfdev's avatar
vfdev committed
634
def affine(
    img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None
) -> Tensor:
    """Apply the 2x3 affine ``matrix`` (6 floats) to ``img``; the output keeps the input size.

    Only "nearest" and "bilinear" interpolation are supported.
    """
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])

    theta_dtype = torch.float32
    if torch.is_floating_point(img):
        theta_dtype = img.dtype
    theta = torch.tensor(matrix, dtype=theta_dtype, device=img.device).reshape(1, 2, 3)

    height = img.shape[-2]
    width = img.shape[-1]
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=width, h=height, ow=width, oh=height)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
vfdev's avatar
vfdev committed
645
646


647
def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    """Return ``(width, height)`` of the box enclosing a ``w`` x ``h`` image after ``matrix``.

    ``rotate`` uses this for ``expand=True``.
    """
    # Inspired of PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # Image corners relative to the image center (the pivot of the torch affine
    # convention), in order: top-left, bottom-left, bottom-right, top-right.
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = torch.tensor(
        [
            [-half_w, -half_h, 1.0],
            [-half_w, half_h, 1.0],
            [half_w, half_h, 1.0],
            [half_w, -half_h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    moved = torch.matmul(corners, theta.T)

    # shift points to [0, w] and [0, h] interval to match PIL results
    offset = torch.tensor((w * 0.5, h * 0.5))
    lo = moved.min(dim=0).values + offset
    hi = moved.max(dim=0).values + offset

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    upper = torch.ceil((hi / tol).trunc_() * tol)
    lower = torch.floor((lo / tol).trunc_() * tol)
    extent = upper - lower
    return int(extent[0]), int(extent[1])
vfdev's avatar
vfdev committed
678
679
680


def rotate(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    expand: bool = False,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Rotate ``img`` with the 2x3 ``matrix`` (6 floats).

    With ``expand=True`` the output is sized by ``_compute_output_size`` to
    hold the whole transformed image; otherwise it keeps the input size.
    """
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])

    width = img.shape[-1]
    height = img.shape[-2]
    if expand:
        out_w, out_h = _compute_output_size(matrix, width, height)
    else:
        out_w, out_h = width, height

    theta_dtype = torch.float32
    if torch.is_floating_point(img):
        theta_dtype = img.dtype
    theta = torch.tensor(matrix, dtype=theta_dtype, device=img.device).reshape(1, 2, 3)
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=width, h=height, ow=out_w, oh=out_h)

    return _apply_grid_transform(img, grid, interpolation, fill=fill)
696
697


698
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    """Build a ``(1, oh, ow, 2)`` sampling grid for ``grid_sample`` from 8 perspective ``coeffs``."""
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394

    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    numerator_theta = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    denominator_theta = torch.tensor(
        [[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device
    )

    # Homogeneous pixel-center coordinates (x, y, 1) of the output image
    half_pixel = 0.5
    coords = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    xs = torch.linspace(half_pixel, ow * 1.0 + half_pixel - 1.0, steps=ow, device=device)
    ys = torch.linspace(half_pixel, oh * 1.0 + half_pixel - 1.0, steps=oh, device=device)
    coords[..., 0].copy_(xs)
    coords[..., 1].copy_(ys.unsqueeze(-1))
    coords[..., 2].fill_(1)

    flat_coords = coords.view(1, oh * ow, 3)
    # The numerator is also divided by half the output size; the final "- 1.0"
    # then shifts the result to grid_sample's normalized coordinates.
    half_output_size = torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
    numer = flat_coords.bmm(numerator_theta.transpose(1, 2) / half_output_size)
    denom = flat_coords.bmm(denominator_theta.transpose(1, 2))

    output_grid = numer / denom - 1.0
    return output_grid.view(1, oh, ow, 2)


def perspective(
    img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None
) -> Tensor:
    """Apply a perspective transform given by 8 coefficients; the output keeps the input size.

    Only "nearest" and "bilinear" interpolation are supported.
    """
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor.")

    _assert_image_tensor(img)

    _assert_grid_transform_inputs(
        img,
        matrix=None,
        interpolation=interpolation,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    grid_dtype = torch.float32
    if torch.is_floating_point(img):
        grid_dtype = img.dtype
    out_w = img.shape[-1]
    out_h = img.shape[-2]
    grid = _perspective_grid(perspective_coeffs, ow=out_w, oh=out_h, dtype=grid_dtype, device=img.device)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
748
749
750
751
752
753
754
755
756
757
758
759
760


def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
    """Return a normalized 1-D Gaussian kernel with ``kernel_size`` taps.

    Taps are evenly spaced over ``[-(kernel_size - 1) / 2, (kernel_size - 1) / 2]``,
    weighted by ``exp(-x**2 / (2 * sigma**2))`` and rescaled to sum to one.
    The tensor uses torch's default dtype and device (no dtype/device is passed).
    """
    half_width = (kernel_size - 1) * 0.5
    positions = torch.linspace(-half_width, half_width, steps=kernel_size)
    weights = (positions / sigma).pow(2).mul(-0.5).exp()
    return weights.div(weights.sum())


def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    """Build a separable 2-D Gaussian kernel as an outer product of two 1-D kernels.

    Args:
        kernel_size: ``[size_x, size_y]``; index 0 is the number of columns,
            index 1 the number of rows of the result.
        sigma: ``[sigma_x, sigma_y]``, paired with ``kernel_size``.
        dtype: dtype of the returned kernel.
        device: device of the returned kernel.

    Returns:
        A tensor of shape ``(kernel_size[1], kernel_size[0])``.
    """
    size_x, size_y = kernel_size[0], kernel_size[1]
    sigma_x, sigma_y = sigma[0], sigma[1]
    column_kernel = _get_gaussian_kernel1d(size_y, sigma_y).to(device, dtype=dtype)
    row_kernel = _get_gaussian_kernel1d(size_x, sigma_x).to(device, dtype=dtype)
    # (size_y, 1) @ (1, size_x) -> (size_y, size_x)
    return torch.mm(column_kernel.unsqueeze(1), row_kernel.unsqueeze(0))


def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    """Blur ``img`` with a Gaussian kernel applied to each channel independently.

    The image is reflect-padded by ``kernel_size // 2`` on each side and then
    convolved depthwise (``groups`` = number of channels). For odd kernel sizes the
    spatial size is therefore preserved.

    Args:
        img: image tensor whose channel dimension is ``-3``.
        kernel_size: ``[size_x, size_y]`` kernel extent along width and height.
        sigma: ``[sigma_x, sigma_y]`` Gaussian standard deviations.

    Returns:
        The blurred image, post-processed by ``_cast_squeeze_out``.

    Raises:
        TypeError: if ``img`` is not a ``torch.Tensor``.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)

    kernel_dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=kernel_dtype, device=img.device)
    # One single-input-channel filter per image channel, for a grouped (depthwise) conv.
    num_channels = img.shape[-3]
    kernel = kernel.expand(num_channels, 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])

    pad_x = kernel_size[0] // 2
    pad_y = kernel_size[1] // 2
    # torch_pad takes the last-two-dims padding as (left, right, top, bottom).
    img = torch_pad(img, [pad_x, pad_x, pad_y, pad_y], mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    return _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)


def invert(img: Tensor) -> Tensor:
    """Return the color-inverted image ``bound - img``.

    ``bound`` is 1 for floating-point images and 255 for all other dtypes.
    The channel count is checked against ``[1, 3]`` by ``_assert_channels``.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions.
    """
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    max_val = 1 if img.is_floating_point() else 255
    return torch.tensor(max_val, dtype=img.dtype, device=img.device) - img


def posterize(img: Tensor, bits: int) -> Tensor:
    """Keep only the ``bits`` most significant bits of every ``uint8`` value.

    The low ``8 - bits`` bits are cleared with a bitwise AND (meaningful for
    ``0 <= bits <= 8``).

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions or is not ``torch.uint8``.
    """
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])

    # -(2 ** k) equals ~(2 ** k - 1) in two's complement; the negated form is used
    # because it is JIT-friendly.
    low_bits = 8 - bits
    mask = -int(2 ** low_bits)
    return img & mask


def solarize(img: Tensor, threshold: float) -> Tensor:
    """Invert (via ``invert``) every value that is greater than or equal to ``threshold``.

    Values below ``threshold`` are returned unchanged.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions, or if ``_assert_threshold``
            rejects ``threshold`` as exceeding the image's value bound.
    """
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])
    _assert_threshold(img, threshold)

    flipped = invert(img)
    at_or_above = img >= threshold
    return torch.where(at_or_above, flipped, img)


def _blurred_degenerate_image(img: Tensor) -> Tensor:
    """Return a smoothed copy of ``img``, used by ``adjust_sharpness`` as the blend target.

    Interior pixels become a 3x3 weighted average (centre weight 5, neighbours
    weight 1, normalized to sum to one). The one-pixel border is copied from
    ``img`` unchanged.
    """
    kernel_dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    weights = torch.ones((3, 3), dtype=kernel_dtype, device=img.device)
    weights[1, 1] = 5.0
    weights /= weights.sum()
    # One filter per channel so the convolution below is depthwise.
    weights = weights.expand(img.shape[-3], 1, weights.shape[0], weights.shape[1])

    blurred, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [weights.dtype])
    # No padding: the result is 2 pixels smaller along each spatial dimension,
    # which is exactly the interior region written back below.
    blurred = conv2d(blurred, weights, groups=blurred.shape[-3])
    blurred = _cast_squeeze_out(blurred, need_cast, need_squeeze, out_dtype)

    result = img.clone()
    result[..., 1:-1, 1:-1] = blurred
    return result


def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust sharpness by blending ``img`` with a smoothed copy of itself.

    ``sharpness_factor`` is passed to ``_blend`` together with the output of
    ``_blurred_degenerate_image``. Images whose height or width is at most 2 are
    returned unchanged. The 3x3 unpadded blur needs at least 3 pixels per side.

    Raises:
        ValueError: if ``sharpness_factor`` is negative.
    """
    if sharpness_factor < 0:
        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    too_small = img.size(-1) <= 2 or img.size(-2) <= 2
    if too_small:
        return img

    degenerate = _blurred_degenerate_image(img)
    return _blend(img, degenerate, sharpness_factor)


def autocontrast(img: Tensor) -> Tensor:
    """Linearly stretch each channel so its minimum maps to 0 and its maximum to ``bound``.

    ``bound`` is 1.0 for floating-point images and 255.0 otherwise. Extrema are
    taken over the last two (spatial) dimensions. The result is clamped to
    ``[0, bound]`` and cast back to ``img.dtype``.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions.
    """
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    bound = 1.0 if img.is_floating_point() else 255.0
    work_dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    low = img.amin(dim=(-2, -1), keepdim=True).to(work_dtype)
    high = img.amax(dim=(-2, -1), keepdim=True).to(work_dtype)
    scale = bound / (high - low)

    # A non-finite scale (e.g. a constant channel where max == min) would corrupt
    # the output; such channels pass through with offset 0 and scale 1.
    degenerate = torch.logical_not(torch.isfinite(scale))
    low[degenerate] = 0
    scale[degenerate] = 1

    stretched = (img - low) * scale
    return stretched.clamp(0, bound).to(img.dtype)


def _scale_channel(img_chan: Tensor) -> Tensor:
    """Histogram-equalize one ``uint8`` channel through a 256-entry lookup table.

    Returns ``img_chan`` itself when fewer than 255 pixels lie outside the highest
    occupied histogram bin, because the step is then zero. Otherwise returns a
    ``uint8`` tensor with the same shape as ``img_chan``.
    """
    # TODO: we should expect bincount to always be faster than histc, but this
    # isn't always the case. Once
    # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
    # block and only use bincount.
    if not img_chan.is_cuda:
        hist = torch.bincount(img_chan.view(-1), minlength=256)
    else:
        hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)

    # step = (number of pixels outside the last occupied bin) // 255
    occupied = hist[hist != 0]
    step = torch.div(occupied[:-1].sum(), 255, rounding_mode="floor")
    if step == 0:
        return img_chan

    half_step = torch.div(step, 2, rounding_mode="floor")
    lut = torch.div(torch.cumsum(hist, 0) + half_step, step, rounding_mode="floor")
    # Shift right by one so entry i is built from the cumulative count of bins < i,
    # then clip to the uint8 range.
    lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)

    return lut[img_chan.to(torch.int64)].to(torch.uint8)


def _equalize_single_image(img: Tensor) -> Tensor:
    """Equalize each slice along dim 0 (the channels of a single image) and restack them."""
    equalized: List[Tensor] = []
    for channel_index in range(img.size(0)):
        equalized.append(_scale_channel(img[channel_index]))
    return torch.stack(equalized)


def equalize(img: Tensor) -> Tensor:
    """Histogram-equalize a ``uint8`` image or a batch of images, channel by channel.

    Args:
        img: ``torch.uint8`` tensor with 3 dimensions (one image) or 4 dimensions
            (a batch, iterated along dim 0).

    Raises:
        TypeError: if ``img`` does not have 3 or 4 dimensions or is not ``torch.uint8``.
    """
    _assert_image_tensor(img)

    if not (3 <= img.ndim <= 4):
        raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])

    if img.ndim == 4:
        # Batched input: equalize every image separately.
        return torch.stack([_equalize_single_image(single) for single in img])
    return _equalize_single_image(img)


def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Normalize a floating-point image tensor as ``(tensor - mean) / std``.

    Args:
        tensor: floating-point tensor of shape ``(..., C, H, W)``.
        mean: values to subtract; a 1-D list is broadcast as ``(C, 1, 1)``.
        std: values to divide by; a 1-D list is broadcast as ``(C, 1, 1)``.
        inplace: if True, ``tensor`` is modified and returned; otherwise a clone is.

    Returns:
        The normalized tensor.

    Raises:
        TypeError: if ``tensor`` is not floating point.
        ValueError: if ``tensor`` has fewer than 3 dimensions, or if any ``std``
            entry is zero after conversion to ``tensor.dtype``.
    """
    _assert_image_tensor(tensor)

    if not tensor.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")

    if tensor.ndim < 3:
        raise ValueError(
            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
        )

    out = tensor if inplace else tensor.clone()

    dtype = out.dtype
    mean_t = torch.as_tensor(mean, dtype=dtype, device=out.device)
    std_t = torch.as_tensor(std, dtype=dtype, device=out.device)
    # A std that becomes zero once converted to the tensor's dtype would divide by zero.
    if (std_t == 0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
    if mean_t.ndim == 1:
        mean_t = mean_t.view(-1, 1, 1)
    if std_t.ndim == 1:
        std_t = std_t.view(-1, 1, 1)
    return out.sub_(mean_t).div_(std_t)


def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Overwrite the ``h x w`` region whose top-left corner is ``(i, j)`` with ``v``.

    Args:
        img: image tensor; the region is sliced from its last two dimensions.
        i: first row of the region.
        j: first column of the region.
        h: number of rows in the region.
        w: number of columns in the region.
        v: values assigned into the region (standard tensor-assignment broadcasting).
        inplace: if True, ``img`` itself is modified; otherwise a clone is.

    Returns:
        The tensor with the region overwritten.
    """
    _assert_image_tensor(img)

    target = img if inplace else img.clone()
    target[..., i : i + h, j : j + w] = v
    return target