import warnings
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad


def _is_tensor_a_torch_image(x: Tensor) -> bool:
    return x.ndim >= 2


def _assert_image_tensor(img: Tensor) -> None:
    if not _is_tensor_a_torch_image(img):
        raise TypeError("Tensor is not a torch image.")


def get_dimensions(img: Tensor) -> List[int]:
    _assert_image_tensor(img)
    channels = 1 if img.ndim == 2 else img.shape[-3]
    height, width = img.shape[-2:]
    return [channels, height, width]


def get_image_size(img: Tensor) -> List[int]:
    # Returns (w, h) of tensor image
    _assert_image_tensor(img)
    return [img.shape[-1], img.shape[-2]]


def get_image_num_channels(img: Tensor) -> int:
    _assert_image_tensor(img)
    if img.ndim == 2:
        return 1
    elif img.ndim > 2:
        return img.shape[-3]

    raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")


def _max_value(dtype: torch.dtype) -> int:
    if dtype == torch.uint8:
        return 255
    elif dtype == torch.int8:
        return 127
    elif dtype == torch.int16:
        return 32767
    elif dtype == torch.int32:
        return 2147483647
    elif dtype == torch.int64:
        return 9223372036854775807
    else:
        # This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
        # easy.
        return 1


def _assert_channels(img: Tensor, permitted: List[int]) -> None:
    c = get_dimensions(img)[0]
    if c not in permitted:
        raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    if image.dtype == dtype:
        return image

    if image.is_floating_point():

        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # For data in the range 0-1, (float * 255).to(uint) is only 255
        # when float is exactly 1.0.
        # `max + 1 - epsilon` provides more evenly distributed mapping of
        # ranges of floats to ints.
        eps = 1e-3
        max_val = float(_max_value(dtype))
        result = image.mul(max_val + 1.0 - eps)
        return result.to(dtype)
    else:
        input_max = float(_max_value(image.dtype))

        # int to float
        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            image = image.to(dtype)
            return image / input_max

        output_max = float(_max_value(dtype))

        # int to int
        if input_max > output_max:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image // factor can produce different results
            factor = int((input_max + 1) // (output_max + 1))
            image = torch.div(image, factor, rounding_mode="floor")
            return image.to(dtype)
        else:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image * factor can produce different results
            factor = int((output_max + 1) // (input_max + 1))
            image = image.to(dtype)
            return image * factor
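# Illustrative usage sketch (added for clarity, not part of the upstream module; values are
# hypothetical): int -> float divides by the input dtype's max, float -> int multiplies by
# max + 1 - eps before truncating, so a uint8 image round-trips.
#   >>> x = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
#   >>> f = convert_image_dtype(x, torch.float32)  # 255 -> 1.0
#   >>> torch.equal(convert_image_dtype(f, torch.uint8), x)
#   True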


def vflip(img: Tensor) -> Tensor:
    _assert_image_tensor(img)

    return img.flip(-2)


def hflip(img: Tensor) -> Tensor:
    _assert_image_tensor(img)

    return img.flip(-1)


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    _assert_image_tensor(img)

    _, h, w = get_dimensions(img)
    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        padding_ltrb = [
            max(-left + min(0, right), 0),
            max(-top + min(0, bottom), 0),
            max(right - max(w, left), 0),
            max(bottom - max(h, top), 0),
        ]
        return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
    return img[..., top:bottom, left:right]
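# Illustrative sketch (hypothetical values, added for clarity): a crop window that extends past
# the image border is zero-padded via pad(), so the output always has the requested size.
#   >>> x = torch.arange(16, dtype=torch.uint8).reshape(1, 4, 4)
#   >>> crop(x, top=-1, left=-1, height=3, width=3).shape
#   torch.Size([1, 3, 3])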


def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])

    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")

    if img.shape[-3] == 3:
        r, g, b = img.unbind(dim=-3)
        # This implementation closely follows the TF one:
        # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
        l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
        l_img = l_img.unsqueeze(dim=-3)
    else:
        l_img = img.clone()

    if num_output_channels == 3:
        return l_img.expand(img.shape)

    return l_img
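# Illustrative sketch (hypothetical values, added for clarity): with the ITU-R 601-2 style
# weights above, a pure red float image maps to a constant gray level of ~0.2989.
#   >>> red = torch.stack([torch.ones(4, 4), torch.zeros(4, 4), torch.zeros(4, 4)])
#   >>> rgb_to_grayscale(red).unique()
#   tensor([0.2989])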


def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    if brightness_factor < 0:
        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])

    return _blend(img, torch.zeros_like(img), brightness_factor)


def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [3, 1])
    c = get_dimensions(img)[0]
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    if c == 3:
        mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
    else:
        mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)

    return _blend(img, mean, contrast_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor image")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])
    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
        return img

    orig_dtype = img.dtype
    img = convert_image_dtype(img, torch.float32)

    img = _rgb2hsv(img)
    h, s, v = img.unbind(dim=-3)
    h = (h + hue_factor) % 1.0
    img = torch.stack((h, s, v), dim=-3)
    img_hue_adj = _hsv2rgb(img)

    return convert_image_dtype(img_hue_adj, orig_dtype)
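# Illustrative sketch (hypothetical values, added for clarity): the shift is a cyclic rotation of
# the hue channel, e.g. hue_factor=0.5 turns pure red into cyan on a float image:
#   >>> red = torch.tensor([1.0, 0.0, 0.0]).view(3, 1, 1)
#   >>> adjust_hue(red, 0.5).flatten()
#   tensor([0., 1., 1.])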


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    if saturation_factor < 0:
        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])

    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
        return img

    return _blend(img, rgb_to_grayscale(img), saturation_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    if not isinstance(img, torch.Tensor):
        raise TypeError("Input img should be a Tensor.")

    _assert_channels(img, [1, 3])

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    result = img
    dtype = img.dtype
    if not torch.is_floating_point(img):
        result = convert_image_dtype(result, torch.float32)

    result = (gain * result**gamma).clamp(0, 1)

    result = convert_image_dtype(result, dtype)
    return result


def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    ratio = float(ratio)
    bound = _max_value(img1.dtype)
    return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
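# Note (added for clarity, not part of the upstream module): ratio=1.0 returns img1 unchanged and
# ratio=0.0 returns img2; the adjust_* functions above pass their user-facing factor straight
# through, e.g. adjust_brightness(img, 0.5) == _blend(img, zeros, 0.5) == 0.5 * img for float input.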


def _rgb2hsv(img: Tensor) -> Tensor:
    r, g, b = img.unbind(dim=-3)

    # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
    # src/libImaging/Convert.c#L330
    maxc = torch.max(img, dim=-3).values
    minc = torch.min(img, dim=-3).values

    # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
    # from happening in the results, because
    #   + S channel has division by `maxc`, which is zero only if `maxc = minc`
    #   + H channel has division by `(maxc - minc)`.
    #
    # Instead of overwriting NaN afterwards, we just prevent it from occurring, so
    # we don't need to deal with it in case we save the NaN in a buffer in
    # backprop, if it is ever supported, but it doesn't hurt to do so.
    eqc = maxc == minc

    cr = maxc - minc
    # Since `eqc => cr = 0`, replacing the denominator with 1 when `eqc` holds is fine.
    ones = torch.ones_like(maxc)
    s = cr / torch.where(eqc, ones, maxc)
    # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
    # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
    # would not matter what values `rc`, `gc`, and `bc` have here, and thus
    # replacing the denominator with 1 when `eqc` holds is fine.
    cr_divisor = torch.where(eqc, ones, cr)
    rc = (maxc - r) / cr_divisor
    gc = (maxc - g) / cr_divisor
    bc = (maxc - b) / cr_divisor

    hr = (maxc == r) * (bc - gc)
    hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
    hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
    h = hr + hg + hb
    h = torch.fmod((h / 6.0 + 1.0), 1.0)
    return torch.stack((h, s, maxc), dim=-3)


def _hsv2rgb(img: Tensor) -> Tensor:
    h, s, v = img.unbind(dim=-3)
    i = torch.floor(h * 6.0)
    f = (h * 6.0) - i
    i = i.to(dtype=torch.int32)

    p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
    q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
    t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
    i = i % 6

    mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)

    a1 = torch.stack((v, q, p, p, t, v), dim=-3)
    a2 = torch.stack((t, v, v, q, p, p), dim=-3)
    a3 = torch.stack((p, p, t, v, v, q), dim=-3)
    a4 = torch.stack((a1, a2, a3), dim=-4)

    return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
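# Note (added for clarity, not part of the upstream module): `mask` is a one-hot selector over the
# six hue sextants and a4 stacks the candidate (R, G, B) values for each sextant, so the einsum
# picks one triple per pixel; e.g. a pixel with i == 0 (hue in [0, 1/6)) gets (R, G, B) = (v, t, p).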


def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
    # padding is left, right, top, bottom

    # crop if needed
    if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
        neg_min_padding = [-min(x, 0) for x in padding]
        crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
        img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
        padding = [max(x, 0) for x in padding]

    in_sizes = img.size()

    _x_indices = [i for i in range(in_sizes[-1])]  # [0, 1, 2, 3, ...]
    left_indices = [i for i in range(padding[0] - 1, -1, -1)]  # e.g. [3, 2, 1, 0]
    right_indices = [-(i + 1) for i in range(padding[1])]  # e.g. [-1, -2, -3]
    x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)

    _y_indices = [i for i in range(in_sizes[-2])]
    top_indices = [i for i in range(padding[2] - 1, -1, -1)]
    bottom_indices = [-(i + 1) for i in range(padding[3])]
    y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)

    ndim = img.ndim
    if ndim == 3:
        return img[:, y_indices[:, None], x_indices[None, :]]
    elif ndim == 4:
        return img[:, :, y_indices[:, None], x_indices[None, :]]
    else:
        raise RuntimeError("Symmetric padding of N-D tensors is not supported yet")


def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This may be unreachable
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif len(padding) == 1:
        pad_left = pad_right = pad_top = pad_bottom = padding[0]
    elif len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    else:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    return [pad_left, pad_right, pad_top, pad_bottom]
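# Illustrative sketch (hypothetical values, added for clarity):
#   >>> _parse_pad_padding(3)             # [3, 3, 3, 3]
#   >>> _parse_pad_padding([1, 2])        # [1, 1, 2, 2]
#   >>> _parse_pad_padding([1, 2, 3, 4])  # [1, 3, 2, 4]  (reordered to left, right, top, bottom)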


def pad(
    img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
) -> Tensor:
    _assert_image_tensor(img)

    if fill is None:
        fill = 0

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list):
        # TODO: Jit is failing on loading this op when scripted and saved
        # https://github.com/pytorch/pytorch/issues/81100
        if len(padding) not in [1, 2, 4]:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    p = _parse_pad_padding(padding)

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)

    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
        # Here we temporarily cast input tensor to float
        # until pytorch issue is resolved :
        # https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    if padding_mode in ("reflect", "replicate"):
        img = torch_pad(img, p, mode=padding_mode)
    else:
        img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img
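# Illustrative usage sketch (hypothetical values, added for clarity): "symmetric" is routed to
# _pad_symmetric() above, while the remaining modes go through torch.nn.functional.pad.
#   >>> x = torch.zeros(3, 32, 32)
#   >>> pad(x, 2, padding_mode="reflect").shape
#   torch.Size([3, 36, 36])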


def resize(
    img: Tensor,
    size: List[int],
    interpolation: str = "bilinear",
    antialias: Optional[bool] = True,
) -> Tensor:
    _assert_image_tensor(img)

    if isinstance(size, tuple):
        size = list(size)

    if antialias is None:
        antialias = False

    if antialias and interpolation not in ["bilinear", "bicubic"]:
        # We manually set it to False to avoid an error downstream in interpolate()
        # This behaviour is documented: the parameter is irrelevant for modes
        # that are not bilinear or bicubic. We used to raise an error here, but
        # now we don't as True is the default.
        antialias = False

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])

    # Define align_corners to avoid warnings
    align_corners = False if interpolation in ["bilinear", "bicubic"] else None

    img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)

    if interpolation == "bicubic" and out_dtype == torch.uint8:
        img = img.clamp(min=0, max=255)

    img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)

    return img
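# Illustrative usage sketch (hypothetical values, added for clarity): `size` is (height, width);
# uint8 inputs are cast to float for interpolate() and rounded back on the way out.
#   >>> x = torch.randint(0, 256, (3, 64, 48), dtype=torch.uint8)
#   >>> resize(x, [32, 24]).dtype, resize(x, [32, 24]).shape
#   (torch.uint8, torch.Size([3, 32, 24]))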


def _assert_grid_transform_inputs(
    img: Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: Optional[Union[int, float, List[float]]],
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:

    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor")

    _assert_image_tensor(img)

    if matrix is not None and not isinstance(matrix, list):
        raise TypeError("Argument matrix should be a list")

    if matrix is not None and len(matrix) != 6:
        raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None and not isinstance(fill, (int, float, tuple, list)):
        warnings.warn("Argument fill should be either int, float, tuple or list")

    # Check fill
    num_channels = get_dimensions(img)[0]
    if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
        msg = (
            "The number of elements in 'fill' cannot broadcast to match the number of "
            "channels of the image ({} != {})"
        )
        raise ValueError(msg.format(len(fill), num_channels))

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
    need_squeeze = False
    # make image NCHW
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if out_dtype not in req_dtypes:
        need_cast = True
        req_dtype = req_dtypes[0]
        img = img.to(req_dtype)
    return img, need_cast, need_squeeze, out_dtype


def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
            # it is better to round before cast
            img = torch.round(img)
        img = img.to(out_dtype)

    return img


def _apply_grid_transform(
    img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])

    if img.shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
        img = torch.cat((img, mask), dim=1)

    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        mask = img[:, -1:, :, :]  # N * 1 * H * W
        img = img[:, :-1, :, :]  # N * C * H * W
        mask = mask.expand_as(img)
        fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
        fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
        if mode == "nearest":
            mask = mask < 0.5
            img[mask] = fill_img[mask]
        else:  # 'bilinear'
            img = img * mask + (1.0 - mask) * fill_img

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img


def _gen_affine_grid(
    theta: Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
    x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


def affine(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
    shape = img.shape
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
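# Illustrative usage sketch (hypothetical values, added for clarity): `matrix` holds the six
# entries of a 2x3 affine map expressed in a coordinate system centered on the image; the identity
# matrix leaves the image (up to interpolation error) unchanged.
#   >>> x = torch.rand(1, 3, 8, 8)
#   >>> affine(x, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], interpolation="bilinear").shape
#   torch.Size([1, 3, 8, 8])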


def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:

    # Inspired by the PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
    # Points are shifted due to affine matrix torch convention about
    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
    pts = torch.tensor(
        [
            [-0.5 * w, -0.5 * h, 1.0],
            [-0.5 * w, 0.5 * h, 1.0],
            [0.5 * w, 0.5 * h, 1.0],
            [0.5 * w, -0.5 * h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, _ = new_pts.min(dim=0)
    max_vals, _ = new_pts.max(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    min_vals += torch.tensor((w * 0.5, h * 0.5))
    max_vals += torch.tensor((w * 0.5, h * 0.5))

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    cmax = torch.ceil((max_vals / tol).trunc_() * tol)
    cmin = torch.floor((min_vals / tol).trunc_() * tol)
    size = cmax - cmin
    return int(size[0]), int(size[1])  # w, h


def rotate(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    expand: bool = False,
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
    w, h = img.shape[-1], img.shape[-2]
    ow, oh = _compute_affine_output_size(matrix, w, h) if expand else (w, h)
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)

    return _apply_grid_transform(img, grid, interpolation, fill=fill)


def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394

    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
    output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))

    output_grid = output_grid1 / output_grid2 - 1.0
    return output_grid.view(1, oh, ow, 2)


def perspective(
    img: Tensor,
    perspective_coeffs: List[float],
    interpolation: str = "bilinear",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor.")

    _assert_image_tensor(img)

    _assert_grid_transform_inputs(
        img,
        matrix=None,
        interpolation=interpolation,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    ow, oh = img.shape[-1], img.shape[-2]
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
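# Illustrative usage sketch (hypothetical values, added for clarity): the eight coefficients follow
# the PIL convention documented in _perspective_grid(); [1, 0, 0, 0, 1, 0, 0, 0] is the identity.
#   >>> x = torch.rand(3, 8, 8)
#   >>> perspective(x, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]).shape
#   torch.Size([3, 8, 8])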


def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, dtype=dtype, device=device)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel1d = pdf / pdf.sum()

    return kernel1d


def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
    kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
    return kernel2d


def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])

    # padding = (left, right, top, bottom)
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img
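# Illustrative usage sketch (hypothetical values, added for clarity): the kernel is the outer
# product of two 1D Gaussians and reflect padding keeps the spatial size unchanged.
#   >>> x = torch.rand(1, 3, 32, 32)
#   >>> gaussian_blur(x, kernel_size=[5, 5], sigma=[1.0, 1.0]).shape
#   torch.Size([1, 3, 32, 32])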


def invert(img: Tensor) -> Tensor:

    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    return _max_value(img.dtype) - img


def posterize(img: Tensor, bits: int) -> Tensor:

    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])
    mask = -int(2 ** (8 - bits))  # JIT-friendly for: ~(2 ** (8 - bits) - 1)
    return img & mask
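# Illustrative sketch (hypothetical values, added for clarity): for bits=3 the mask is -32, whose
# byte pattern is 0b11100000, so only the top three bits of each value are kept:
#   >>> posterize(torch.tensor([[[255, 200, 100]]], dtype=torch.uint8), 3).flatten()
#   tensor([224, 192,  96], dtype=torch.uint8)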


def solarize(img: Tensor, threshold: float) -> Tensor:

    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    if threshold > _max_value(img.dtype):
        raise TypeError("Threshold should be less than bound of img.")

    inverted_img = invert(img)
    return torch.where(img >= threshold, inverted_img, img)


def _blurred_degenerate_image(img: Tensor) -> Tensor:
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
    kernel[1, 1] = 5.0
    kernel /= kernel.sum()
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
    result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
    result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)

    result = img.clone()
    result[..., 1:-1, 1:-1] = result_tmp

    return result


def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    if sharpness_factor < 0:
        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])

    if img.size(-1) <= 2 or img.size(-2) <= 2:
        return img

    return _blend(img, _blurred_degenerate_image(img), sharpness_factor)


def autocontrast(img: Tensor) -> Tensor:

    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    bound = _max_value(img.dtype)
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
    maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
    scale = bound / (maximum - minimum)
    eq_idxs = torch.isfinite(scale).logical_not()
    minimum[eq_idxs] = 0
    scale[eq_idxs] = 1

    return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)


def _scale_channel(img_chan: Tensor) -> Tensor:
    # TODO: we should expect bincount to always be faster than histc, but this
    # isn't always the case. Once
    # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
    # block and only use bincount.
    if img_chan.is_cuda:
        hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
    else:
        hist = torch.bincount(img_chan.reshape(-1), minlength=256)

    nonzero_hist = hist[hist != 0]
    step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
    if step == 0:
        return img_chan

    lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
    lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)

    return lut[img_chan.to(torch.int64)].to(torch.uint8)


def _equalize_single_image(img: Tensor) -> Tensor:
    return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])


def equalize(img: Tensor) -> Tensor:

    _assert_image_tensor(img)

    if not (3 <= img.ndim <= 4):
        raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])

    if img.ndim == 3:
        return _equalize_single_image(img)

    return torch.stack([_equalize_single_image(x) for x in img])
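# Illustrative note (added for clarity, not part of the upstream module): equalization stretches
# each channel's histogram over 0-255; a constant channel hits the `step == 0` early return in
# _scale_channel() and is left untouched, e.g.
#   >>> equalize(torch.full((1, 4, 4), 7, dtype=torch.uint8)).unique()
#   tensor([7], dtype=torch.uint8)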


def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    _assert_image_tensor(tensor)

    if not tensor.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")

    if tensor.ndim < 3:
        raise ValueError(
            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
        )

    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std == 0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    return tensor.sub_(mean).div_(std)
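# Illustrative usage sketch (added for clarity; the statistics below are the commonly used
# ImageNet values, not something this module defines):
#   >>> x = torch.rand(3, 224, 224)
#   >>> normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).shape
#   torch.Size([3, 224, 224])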


def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    _assert_image_tensor(img)

    if not inplace:
        img = img.clone()

    img[..., i : i + h, j : j + w] = v
    return img


def _create_identity_grid(size: List[int]) -> Tensor:
    hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
    grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
    return torch.stack([grid_x, grid_y], -1).unsqueeze(0)  # 1 x H x W x 2


def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: str = "bilinear",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:

    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    size = list(img.shape[-2:])
    displacement = displacement.to(img.device)

    identity_grid = _create_identity_grid(size)
    grid = identity_grid.to(img.device) + displacement
    return _apply_grid_transform(img, grid, interpolation, fill)
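# Illustrative usage sketch (hypothetical values, added for clarity): `displacement` is a
# 1 x H x W x 2 offset field in the normalized [-1, 1] grid coordinates used by grid_sample; a
# zero field is (up to interpolation error) a no-op.
#   >>> x = torch.rand(3, 16, 16)
#   >>> elastic_transform(x, torch.zeros(1, 16, 16, 2)).shape
#   torch.Size([3, 16, 16])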