functional_tensor.py 33 KB
Newer Older
vfdev's avatar
vfdev committed
1
import warnings
2
from typing import List, Optional, Tuple, Union
vfdev's avatar
vfdev committed
3

4
import torch
5
from torch import Tensor
6
from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad
7
8


vfdev's avatar
vfdev committed
9
10
def _is_tensor_a_torch_image(x: Tensor) -> bool:
    """Tell whether ``x`` can be interpreted as an image tensor.

    Any tensor with two or more dimensions qualifies; the rest of this module
    reads the last two dimensions as height and width.
    """
    num_dims = x.dim()
    return num_dims >= 2
11
12


13
def _assert_image_tensor(img: Tensor) -> None:
    """Validate that ``img`` is usable as an image tensor.

    Raises:
        TypeError: if ``img`` has fewer than two dimensions.
    """
    if _is_tensor_a_torch_image(img):
        return
    raise TypeError("Tensor is not a torch image.")


puhuk's avatar
puhuk committed
18
19
20
21
22
23
def _assert_threshold(img: Tensor, threshold: float) -> None:
    """Reject a threshold that exceeds the value range of ``img``.

    Floating-point images are bounded by 1, every other dtype by 255.
    A threshold equal to the bound is accepted.

    Raises:
        TypeError: if ``threshold`` is greater than that bound.
    """
    max_allowed = 255
    if img.is_floating_point():
        max_allowed = 1
    if threshold > max_allowed:
        raise TypeError("Threshold should be less than bound of img.")


24
25
26
27
28
29
30
def get_dimensions(img: Tensor) -> List[int]:
    """Return ``[channels, height, width]`` of an image tensor.

    A 2-D tensor is treated as a single-channel image; for higher ranks the
    channel count is read from dim -3.
    """
    _assert_image_tensor(img)
    num_channels = img.shape[-3] if img.ndim > 2 else 1
    return [num_channels, img.shape[-2], img.shape[-1]]


31
def get_image_size(img: Tensor) -> List[int]:
    """Return the spatial size of an image tensor as ``[width, height]``.

    Note the order: width comes first, unlike ``get_dimensions``.
    """
    _assert_image_tensor(img)
    height, width = img.shape[-2], img.shape[-1]
    return [width, height]
vfdev's avatar
vfdev committed
35
36


37
def get_image_num_channels(img: Tensor) -> int:
    """Return the number of channels of an image tensor.

    2-D tensors count as single-channel; higher-rank tensors report dim -3.
    """
    _assert_image_tensor(img)
    rank = img.ndim
    if rank > 2:
        return img.shape[-3]
    if rank == 2:
        return 1
    # Defensive only: _assert_image_tensor already rejects rank < 2.
    raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
45
46


47
48
def _max_value(dtype: torch.dtype) -> int:
    """Return the largest value an image of ``dtype`` is expected to hold.

    Integer dtypes map to their maximum representable value. Every other dtype
    (floating point included) returns 1, i.e. images are assumed normalised to [0, 1].
    """
    if dtype == torch.int64:
        return 9223372036854775807
    if dtype == torch.int32:
        return 2147483647
    if dtype == torch.int16:
        return 32767
    if dtype == torch.int8:
        return 127
    if dtype == torch.uint8:
        return 255
    return 1
60
61


62
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
    """Check that the channel count of ``img`` is one of ``permitted``.

    Raises:
        TypeError: if ``img`` is not an image tensor or its channel count is not permitted.
    """
    num_channels = get_dimensions(img)[0]
    if num_channels in permitted:
        return
    raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {num_channels}")
66
67


68
69
70
71
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert ``image`` to ``dtype``, rescaling values between the dtypes' ranges.

    Floating-point images are treated as spanning [0, 1] and integer images as
    spanning [0, _max_value(dtype)]. Conversion paths:
      * float -> float: plain cast.
      * float -> int: scaled by ``max + 1 - eps`` and cast.
      * int -> float: cast and divided by the input dtype's max.
      * int -> int: floor-divided (narrowing) or multiplied (widening) by an
        integer power-of-two-like factor.

    Args:
        image: tensor to convert.
        dtype: target dtype.

    Returns:
        ``image`` itself when it already has ``dtype``; otherwise a new tensor.

    Raises:
        RuntimeError: for float32 -> int32/int64 and float64 -> int64, which
            cannot be performed safely.
    """
    if image.dtype == dtype:
        return image

    if image.is_floating_point():
        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            # float -> float: no rescaling needed, both sides use [0, 1].
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # For data in the range 0-1, (float * 255).to(uint) is only 255
        # when float is exactly 1.0.
        # `max + 1 - epsilon` provides more evenly distributed mapping of
        # ranges of floats to ints.
        eps = 1e-3
        max_val = float(_max_value(dtype))
        # NOTE(review): there is no clamp before the cast; values outside [0, 1]
        # are scaled and cast as-is.
        result = image.mul(max_val + 1.0 - eps)
        return result.to(dtype)
    else:
        input_max = float(_max_value(image.dtype))

        # int to float
        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            image = image.to(dtype)
            return image / input_max

        output_max = float(_max_value(dtype))

        # int to int
        if input_max > output_max:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image // factor can produce different results
            factor = int((input_max + 1) // (output_max + 1))
            image = torch.div(image, factor, rounding_mode="floor")
            return image.to(dtype)
        else:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image * factor can produce different results
            factor = int((output_max + 1) // (input_max + 1))
            # Cast first so the multiplication happens in the wider output dtype.
            image = image.to(dtype)
            return image * factor


vfdev's avatar
vfdev committed
120
def vflip(img: Tensor) -> Tensor:
    """Flip ``img`` vertically by reversing dim -2 (the height axis)."""
    _assert_image_tensor(img)
    return torch.flip(img, dims=[-2])
124
125


vfdev's avatar
vfdev committed
126
def hflip(img: Tensor) -> Tensor:
    """Flip ``img`` horizontally by reversing dim -1 (the width axis)."""
    _assert_image_tensor(img)
    return torch.flip(img, dims=[-1])
ekka's avatar
ekka committed
130
131


vfdev's avatar
vfdev committed
132
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop a ``height`` x ``width`` window whose top-left corner is at (``top``, ``left``).

    The window may extend past the image borders, or lie entirely outside the
    image. Areas the image does not cover are zero-padded, so for non-negative
    ``height``/``width`` the result always has spatial size ``height`` x ``width``.

    Raises:
        TypeError: if ``img`` is not an image tensor.
    """
    _assert_image_tensor(img)

    _, h, w = get_dimensions(img)
    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        # Clamp the slice stops at 0: a negative stop would be interpreted as
        # "count from the end" and select the wrong rows/columns.
        cropped = img[..., max(top, 0) : max(bottom, 0), max(left, 0) : max(right, 0)]
        # Pad each side by the part of the window the slice could not cover.
        # The inner min/max terms handle windows lying completely outside the image
        # (e.g. starting beyond the far edge).
        padding_ltrb = [
            max(min(right, 0) - left, 0),
            max(min(bottom, 0) - top, 0),
            max(right - max(w, left), 0),
            max(bottom - max(h, top), 0),
        ]
        return pad(cropped, padding_ltrb, fill=0)
    return img[..., top:bottom, left:right]
143
144


145
146
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert an RGB (or already single-channel) image to grayscale.

    Args:
        img: image tensor with at least 3 dims and 1 or 3 channels at dim -3.
        num_output_channels: 1 or 3. With 3, the luminance channel is repeated
            (as an expanded view) across three channels.

    Returns:
        Tensor with ``img``'s dtype and ``num_output_channels`` channels at dim -3.

    Raises:
        TypeError: if ``img`` has fewer than 3 dims or a channel count other than 1 or 3.
        ValueError: if ``num_output_channels`` is not 1 or 3.
    """
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])

    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")

    if img.shape[-3] == 3:
        r, g, b = img.unbind(dim=-3)
        # This implementation closely follows the TF one:
        # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
        l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
        l_img = l_img.unsqueeze(dim=-3)
    else:
        l_img = img.clone()

    if num_output_channels == 3:
        # Expand to exactly three channels. Expanding to ``img.shape`` would keep a
        # single channel whenever the input itself is single-channel.
        out_shape = list(l_img.shape)
        out_shape[-3] = 3
        return l_img.expand(out_shape)

    return l_img
166
167


vfdev's avatar
vfdev committed
168
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Scale image intensities by ``brightness_factor``.

    Implemented as a blend with an all-zero (black) image, so ``_blend`` clamps
    the result to the valid range of ``img``'s dtype.

    Raises:
        ValueError: if ``brightness_factor`` is negative.
        TypeError: if ``img`` is not an image tensor with 1 or 3 channels.
    """
    if brightness_factor < 0:
        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    black = torch.zeros_like(img)
    return _blend(img, black, brightness_factor)
177
178


vfdev's avatar
vfdev committed
179
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Blend ``img`` towards its mean luminance to change contrast.

    The mean is taken over the last three dims (of the grayscale version for
    RGB inputs), so each image in a batch gets its own mean.

    Raises:
        ValueError: if ``contrast_factor`` is negative.
        TypeError: if ``img`` is not an image tensor with 1 or 3 channels.
    """
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [3, 1])

    num_channels = get_dimensions(img)[0]
    # Integer images are averaged in float32.
    mean_dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    luminance = rgb_to_grayscale(img) if num_channels == 3 else img
    mean = torch.mean(luminance.to(mean_dtype), dim=(-3, -2, -1), keepdim=True)
    return _blend(img, mean, contrast_factor)


196
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Shift the hue of an RGB image by ``hue_factor`` (a fraction of a full turn).

    The image is taken through HSV space, and ``hue_factor`` is added to H modulo 1.
    Single-channel images are returned unchanged (matching PIL). uint8 inputs
    are processed in float32 on a [0, 1] scale and cast back to uint8 at the end.

    Raises:
        ValueError: if ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: if ``img`` is not a tensor image with 1 or 3 channels.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor image")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
        return img

    is_uint8 = img.dtype == torch.uint8
    work = img.to(dtype=torch.float32) / 255.0 if is_uint8 else img

    hue, sat, val = _rgb2hsv(work).unbind(dim=-3)
    hue = (hue + hue_factor) % 1.0
    rgb = _hsv2rgb(torch.stack((hue, sat, val), dim=-3))

    if is_uint8:
        rgb = (rgb * 255.0).to(dtype=torch.uint8)
    return rgb


vfdev's avatar
vfdev committed
225
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Blend ``img`` with its grayscale version to change colour saturation.

    A factor of 0 gives grayscale and 1 the original image. Larger values
    over-saturate, with clamping done by ``_blend``. Single-channel images are
    returned as is.

    Raises:
        ValueError: if ``saturation_factor`` is negative.
        TypeError: if ``img`` is not an image tensor with 1 or 3 channels.
    """
    if saturation_factor < 0:
        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    is_single_channel = get_dimensions(img)[0] == 1
    if is_single_channel:  # Match PIL behaviour
        return img

    gray = rgb_to_grayscale(img)
    return _blend(img, gray, saturation_factor)
237
238


239
240
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    """Apply gamma correction ``gain * img ** gamma``.

    Integer images are converted to float32 in [0, 1] with ``convert_image_dtype``.
    The corrected values are clamped to [0, 1] and converted back to the
    original dtype.

    Raises:
        TypeError: if ``img`` is not a tensor, or not an image with 1 or 3 channels.
        ValueError: if ``gamma`` is negative.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError("Input img should be a Tensor.")

    _assert_channels(img, [1, 3])

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    original_dtype = img.dtype
    as_float = img if torch.is_floating_point(img) else convert_image_dtype(img, torch.float32)
    corrected = (gain * as_float**gamma).clamp(0, 1)
    return convert_image_dtype(corrected, original_dtype)


vfdev's avatar
vfdev committed
259
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    """Linearly blend two images: ``ratio * img1 + (1 - ratio) * img2``.

    The result is clamped to [0, 1] when ``img1`` is floating point (to [0, 255]
    otherwise) and cast back to ``img1``'s dtype.
    """
    weight = float(ratio)
    upper = 255.0
    if img1.is_floating_point():
        upper = 1.0
    mixed = weight * img1 + (1.0 - weight) * img2
    return mixed.clamp(0, upper).to(img1.dtype)
263
264


265
def _rgb2hsv(img: Tensor) -> Tensor:
    """Convert an RGB image to HSV.

    Channels are read from dim -3 as (r, g, b). The result stacks (h, s, v)
    along dim -3:
      * ``v`` is the per-pixel channel maximum;
      * ``s`` is ``(max - min) / max`` (0 where max == min);
      * ``h`` is the hue as a fraction of a full turn, wrapped into [0, 1) by
        the final ``fmod``.

    NOTE(review): callers in this file pass floating-point images; presumably
    values are expected in [0, 1]. Confirm for other callers.
    """
    r, g, b = img.unbind(dim=-3)

    # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
    # src/libImaging/Convert.c#L330
    maxc = torch.max(img, dim=-3).values
    minc = torch.min(img, dim=-3).values

    # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
    # from happening in the results, because
    #   + S channel has division by `maxc`, which is zero only if `maxc = minc`
    #   + H channel has division by `(maxc - minc)`.
    #
    # Instead of overwriting NaN afterwards, we just prevent it from occurring so
    # we don't need to deal with it in case we save the NaN in a buffer in
    # backprop, if it is ever supported, but it doesn't hurt to do so.
    eqc = maxc == minc

    cr = maxc - minc
    # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
    ones = torch.ones_like(maxc)
    s = cr / torch.where(eqc, ones, maxc)
    # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
    # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
    # would not matter what values `rc`, `gc`, and `bc` have here, and thus
    # replacing denominator with 1 when `eqc` is fine.
    cr_divisor = torch.where(eqc, ones, cr)
    rc = (maxc - r) / cr_divisor
    gc = (maxc - g) / cr_divisor
    bc = (maxc - b) / cr_divisor

    # Exactly one of the three masks below is true per pixel (red wins ties,
    # then green), selecting the hue formula for the channel holding the maximum.
    hr = (maxc == r) * (bc - gc)
    hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
    hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
    h = hr + hg + hb
    # h is measured in sixths of a turn and may be negative; map it into [0, 1).
    h = torch.fmod((h / 6.0 + 1.0), 1.0)
    return torch.stack((h, s, maxc), dim=-3)
302
303


304
def _hsv2rgb(img: Tensor) -> Tensor:
    """Convert an HSV image (h, s, v stacked along dim -3) back to RGB.

    Uses the classic six-sector formula. The hue circle is split into six
    sectors, and in each sector every output channel equals one of ``v``,
    ``p``, ``q`` or ``t``. Returns r, g, b stacked along dim -3.
    """
    h, s, v = img.unbind(dim=-3)
    # Sector index i and the fractional position f inside that sector.
    i = torch.floor(h * 6.0)
    f = (h * 6.0) - i
    i = i.to(dtype=torch.int32)

    p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
    q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
    t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
    i = i % 6

    # One-hot sector mask: a new dim of size 6 is inserted at -3, true where i == k.
    mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)

    # a1 / a2 / a3 hold the r / g / b value for each of the 6 sectors.
    a1 = torch.stack((v, q, p, p, t, v), dim=-3)
    a2 = torch.stack((t, v, v, q, p, p), dim=-3)
    a3 = torch.stack((p, p, t, v, v, q), dim=-3)
    a4 = torch.stack((a1, a2, a3), dim=-4)

    # For every pixel, sum over the sector axis with the one-hot mask, i.e. pick
    # the active sector's value for each of r, g, b.
    return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
323
324


325
326
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
    """Pad ``img`` by mirroring it about its edges, edge pixels included.

    Args:
        img: image tensor with 3 or 4 dims; any other rank raises ``RuntimeError``.
        padding: ``[left, right, top, bottom]``; negative entries crop that side.

    NOTE(review): a padding larger than the (cropped) image size along that axis
    produces out-of-range gather indices. Presumably callers keep padding within
    the image size.
    """
    # padding is left, right, top, bottom

    # crop if needed
    if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
        neg_min_padding = [-min(x, 0) for x in padding]
        crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
        img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
        padding = [max(x, 0) for x in padding]

    in_sizes = img.size()

    # Gather indices per axis: mirrored prefix + identity + mirrored suffix.
    # Negative indices in the suffix address the image from its far end.
    _x_indices = [i for i in range(in_sizes[-1])]  # [0, 1, 2, 3, ...]
    left_indices = [i for i in range(padding[0] - 1, -1, -1)]  # e.g. [3, 2, 1, 0]
    right_indices = [-(i + 1) for i in range(padding[1])]  # e.g. [-1, -2, -3]
    x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)

    _y_indices = [i for i in range(in_sizes[-2])]
    top_indices = [i for i in range(padding[2] - 1, -1, -1)]
    bottom_indices = [-(i + 1) for i in range(padding[3])]
    y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)

    # Broadcast row (column-vector) and column (row-vector) indices into a 2-D
    # gather over the last two dims.
    ndim = img.ndim
    if ndim == 3:
        return img[:, y_indices[:, None], x_indices[None, :]]
    elif ndim == 4:
        return img[:, :, y_indices[:, None], x_indices[None, :]]
    else:
        raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")


356
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    """Normalise a ``pad`` padding spec to torch's ``[left, right, top, bottom]`` order.

    Accepted forms: an int (all sides), ``[p]`` (all sides),
    ``[left/right, top/bottom]``, or ``[left, top, right, bottom]``.
    """
    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This may be unreachable
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        result = [padding, padding, padding, padding]
    elif len(padding) == 1:
        result = [padding[0], padding[0], padding[0], padding[0]]
    elif len(padding) == 2:
        result = [padding[0], padding[0], padding[1], padding[1]]
    else:
        # (left, top, right, bottom) -> (left, right, top, bottom)
        result = [padding[0], padding[2], padding[1], padding[3]]
    return result


376
def pad(
    img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
) -> Tensor:
    """Pad the last two dimensions of ``img``.

    Args:
        img: image tensor (at least 2 dims).
        padding: an int (all sides), or a 1-, 2- or 4-element list/tuple. Two
            elements mean (left/right, top/bottom); four mean (left, top, right, bottom).
        fill: value used by ``"constant"`` mode; ``None`` is treated as 0.
        padding_mode: ``"constant"``, ``"edge"``, ``"reflect"`` or ``"symmetric"``.

    Returns:
        Padded tensor with the same rank and dtype as ``img``.

    Raises:
        TypeError: if ``img`` is not an image tensor or an argument has the wrong type.
        ValueError: for an unsupported padding length or padding mode.
    """
    _assert_image_tensor(img)

    if fill is None:
        fill = 0

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list):
        # TODO: Jit is failing on loading this op when scripted and saved
        # https://github.com/pytorch/pytorch/issues/81100
        if len(padding) not in [1, 2, 4]:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    # p is in torch order: [left, right, top, bottom].
    p = _parse_pad_padding(padding)

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)

    # Work on a 4-D (batched) tensor; the added dim is removed again below.
    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
        # Here we temporary cast input tensor to float
        # until pytorch issue is resolved :
        # https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    # Only constant mode takes a fill value.
    if padding_mode in ("reflect", "replicate"):
        img = torch_pad(img, p, mode=padding_mode)
    else:
        img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img
vfdev's avatar
vfdev committed
440
441


442
443
444
445
def resize(
    img: Tensor,
    size: List[int],
    interpolation: str = "bilinear",
    antialias: Optional[bool] = None,
) -> Tensor:
    """Resize the spatial dimensions of ``img`` with ``torch.nn.functional.interpolate``.

    Args:
        img: image tensor. A leading batch dim is added when ndim < 4 before
            calling ``interpolate`` (presumably 3-D or 4-D input is expected;
            confirm with callers).
        size: output spatial size, passed straight to ``interpolate``.
        interpolation: ``interpolate`` mode name, e.g. "nearest", "bilinear", "bicubic".
        antialias: enable antialiasing (bilinear/bicubic only); ``None`` means False.

    Returns:
        Resized tensor, restored to the input's rank and dtype.

    Raises:
        TypeError: if ``img`` is not an image tensor.
        ValueError: if ``antialias`` is requested with an unsupported interpolation mode.
    """
    _assert_image_tensor(img)

    if isinstance(size, tuple):
        size = list(size)

    if antialias is None:
        antialias = False

    if antialias and interpolation not in ["bilinear", "bicubic"]:
        raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")

    # Non-float inputs are interpolated in float32 and cast back (with rounding) at the end.
    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])

    # Define align_corners to avoid warnings
    align_corners = False if interpolation in ["bilinear", "bicubic"] else None

    img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)

    # Bicubic interpolation can overshoot the input range; clamp before the uint8 cast.
    if interpolation == "bicubic" and out_dtype == torch.uint8:
        img = img.clamp(min=0, max=255)

    img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)

    return img
vfdev's avatar
vfdev committed
472
473


vfdev's avatar
vfdev committed
474
def _assert_grid_transform_inputs(
    img: Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: Optional[Union[int, float, List[float]]],
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    """Validate the arguments shared by affine, rotate and perspective.

    Args:
        img: input image tensor.
        matrix: flat 2x3 affine matrix (6 values), or ``None`` when not used.
        interpolation: requested interpolation mode.
        fill: fill value(s); a list must have length 1 or match the channel count.
        supported_interpolation_modes: modes accepted for this transform.
        coeffs: perspective coefficients (8 values), or ``None`` when not used.

    Raises:
        TypeError: if ``img`` is not an image tensor or ``matrix`` is not a list.
        ValueError: for a wrong number of matrix/coeff values, a ``fill`` that
            cannot broadcast to the channels, or an unsupported interpolation mode.

    Only warns (does not raise) when ``fill`` has an unexpected type.
    """
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor")

    _assert_image_tensor(img)

    if matrix is not None and not isinstance(matrix, list):
        raise TypeError("Argument matrix should be a list")

    if matrix is not None and len(matrix) != 6:
        raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None and not isinstance(fill, (int, float, tuple, list)):
        warnings.warn("Argument fill should be either int, float, tuple or list")

    # Check fill: a single value broadcasts; otherwise one value per channel is required.
    num_channels = get_dimensions(img)[0]
    if fill is not None and isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
        msg = (
            "The number of elements in 'fill' cannot broadcast to match the number of "
            "channels of the image ({} != {})"
        )
        raise ValueError(msg.format(len(fill), num_channels))

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
vfdev's avatar
vfdev committed
511
512


vfdev's avatar
vfdev committed
513
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
    """Prepare ``img`` for ops that need a batched (NCHW) tensor of an accepted dtype.

    A leading batch dim is added when ``img`` has fewer than 4 dims. The tensor
    is cast to ``req_dtypes[0]`` when its dtype is not in ``req_dtypes``.

    Returns:
        ``(img, need_cast, need_squeeze, out_dtype)``. The two flags and the
        original dtype are meant to be handed to ``_cast_squeeze_out``.
    """
    original_dtype = img.dtype
    need_squeeze = img.ndim < 4
    if need_squeeze:
        img = img.unsqueeze(dim=0)

    need_cast = original_dtype not in req_dtypes
    if need_cast:
        img = img.to(req_dtypes[0])
    return img, need_cast, need_squeeze, original_dtype
vfdev's avatar
vfdev committed
527
528


529
def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
    """Undo ``_cast_squeeze_in``: drop the added batch dim and restore the dtype.

    When restoring an integer dtype the values are rounded first, so the cast
    does not simply truncate them.
    """
    result = img.squeeze(dim=0) if need_squeeze else img
    if not need_cast:
        return result
    if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
        # Round before casting so e.g. 2.7 becomes 3 rather than 2.
        result = torch.round(result)
    return result.to(out_dtype)
vfdev's avatar
vfdev committed
540
541


542
543
544
def _apply_grid_transform(
    img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:
    """Resample ``img`` at the locations in ``grid`` and optionally fill out-of-bounds areas.

    Args:
        img: image tensor. It is made 4-D and cast to ``grid.dtype`` for
            sampling, then restored to its original rank and dtype.
        grid: ``grid_sample`` grid. The grid helpers in this file produce
            ``(1, H_out, W_out, 2)``; it is expanded across the batch here.
        mode: ``grid_sample`` interpolation mode.
        fill: value(s) for pixels sampled from outside the input. With ``None``
            those pixels stay zero (``padding_mode="zeros"``).

    Returns:
        The resampled image.
    """
    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])

    if img.shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
        img = torch.cat((img, mask), dim=1)

    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        # After sampling, the mask channel is 1 where the source was fully inside
        # the image, 0 fully outside, and fractional at borders (bilinear).
        mask = img[:, -1:, :, :]  # N * 1 * H * W
        img = img[:, :-1, :, :]  # N * C * H * W
        mask = mask.expand_as(img)
        fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
        fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
        if mode == "nearest":
            # Hard replacement: pixels that are mostly outside take the fill value.
            mask = mask < 0.5
            img[mask] = fill_img[mask]
        else:  # 'bilinear'
            # Soft blend with the fill colour, weighted by the sampled mask.
            img = img * mask + (1.0 - mask) * fill_img

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img


576
def _gen_affine_grid(
    theta: Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> Tensor:
    """Build a ``(1, oh, ow, 2)`` ``grid_sample`` grid from an affine ``theta``.

    Output pixel centres are laid out in pixel units, centred on the output
    image's centre. They are mapped through ``theta`` (reshaped by callers to
    ``(1, 2, 3)``). Finally they are normalised by half of the *input* size
    (``w``, ``h``) into ``grid_sample``'s [-1, 1] coordinate convention.

    Args:
        theta: ``(1, 2, 3)`` affine matrix; its dtype/device are used for the grid.
        w, h: input width and height (used for normalisation).
        ow, oh: output width and height (grid resolution).
    """
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "expand" option like in PIL.Image.rotate

    d = 0.5
    # Homogeneous coordinates (x, y, 1) for every output pixel centre.
    base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
    x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    # Fold the pixel -> [-1, 1] normalisation into theta before the batched matmul.
    rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


vfdev's avatar
vfdev committed
602
def affine(
    img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None
) -> Tensor:
    """Warp ``img`` with a 2x3 affine ``matrix`` given as 6 row-major values.

    The output keeps the input's spatial size. The matrix acts on pixel
    coordinates centred on the image centre (see ``_gen_affine_grid``).
    """
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])

    height, width = img.shape[-2], img.shape[-1]
    theta_dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=theta_dtype, device=img.device).reshape(1, 2, 3)
    # The grid is created on theta's device, i.e. the same device as img.
    grid = _gen_affine_grid(theta, w=width, h=height, ow=width, oh=height)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
vfdev's avatar
vfdev committed
613
614


vfdev's avatar
vfdev committed
615
def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    """Return ``(width, height)`` of the canvas that holds the whole transformed image.

    Mirrors PIL's ``expand=True`` logic. The four image corners (in
    centre-origin coordinates) are mapped through the 2x3 ``matrix`` and the
    bounding box of the mapped points is measured.
    """
    # Inspired of PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # Corners in order top-left, bottom-left, bottom-right, top-right, with the
    # origin at the image centre (torch affine convention), as homogeneous rows.
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = torch.tensor(
        [
            [-half_w, -half_h, 1.0],
            [-half_w, half_h, 1.0],
            [half_w, half_h, 1.0],
            [half_w, -half_h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    mapped = torch.matmul(corners, theta.T)

    lo = mapped.min(dim=0).values
    hi = mapped.max(dim=0).values

    # Move back to a top-left origin ([0, w] x [0, h]) to match PIL results.
    offset = torch.tensor((w * 0.5, h * 0.5))
    lo = lo + offset
    hi = hi + offset

    # Truncate to 1e-4 precision first so float noise (e.g. 1e-15) is not
    # ceiled/floored into an extra pixel.
    tol = 1e-4
    upper = torch.ceil((hi / tol).trunc_() * tol)
    lower = torch.floor((lo / tol).trunc_() * tol)
    extent = upper - lower
    return int(extent[0]), int(extent[1])  # w, h
vfdev's avatar
vfdev committed
646
647
648


def rotate(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    expand: bool = False,
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Rotate ``img`` using a precomputed 2x3 affine ``matrix`` (6 row-major values).

    With ``expand=True`` the output canvas is enlarged to contain the whole
    transformed image; otherwise it keeps the input size.
    """
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
    in_w = img.shape[-1]
    in_h = img.shape[-2]
    if expand:
        out_w, out_h = _compute_affine_output_size(matrix, in_w, in_h)
    else:
        out_w, out_h = in_w, in_h

    theta_dtype = torch.float32
    if torch.is_floating_point(img):
        theta_dtype = img.dtype
    theta = torch.tensor(matrix, dtype=theta_dtype, device=img.device).reshape(1, 2, 3)
    # Normalised by the input size (w, h), laid out over the output size (ow, oh).
    # The grid is created on the same device as theta and img.
    grid = _gen_affine_grid(theta, w=in_w, h=in_h, ow=out_w, oh=out_h)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
664
665


666
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    """Build a ``(1, oh, ow, 2)`` ``grid_sample`` grid for a perspective warp.

    For each output pixel centre (x, y) (pixel units, origin at the top-left),
    the source location is computed with the PIL formula below. It is then
    mapped into ``grid_sample``'s [-1, 1] range by dividing by half of the
    output size and subtracting 1.

    Args:
        coeffs: 8 perspective coefficients (PIL convention).
        ow, oh: output width and height.
        dtype, device: dtype and device of the returned grid.
    """
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394

    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    # theta1 holds the two numerators; theta2 repeats the shared denominator for both outputs.
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    # Homogeneous (x, y, 1) coordinates of every output pixel centre.
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    # Numerators are pre-scaled by half the output size so the quotient lands in [0, 2].
    rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
    output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))

    # Shift [0, 2] to [-1, 1] for grid_sample.
    output_grid = output_grid1 / output_grid2 - 1.0
    return output_grid.view(1, oh, ow, 2)


def perspective(
    img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None
) -> Tensor:
    """Apply a perspective warp described by 8 PIL-style coefficients.

    The output keeps the input's spatial size.

    Raises:
        TypeError: if ``img`` is not an image tensor.
        ValueError: for invalid coefficients, fill, or interpolation mode.
    """
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor.")

    _assert_image_tensor(img)

    _assert_grid_transform_inputs(
        img,
        matrix=None,
        interpolation=interpolation,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    grid_dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    grid = _perspective_grid(
        perspective_coeffs, ow=img.shape[-1], oh=img.shape[-2], dtype=grid_dtype, device=img.device
    )
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
716
717
718
719
720
721
722
723
724
725
726
727
728


def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
    """Return a normalised 1-D Gaussian kernel of length ``kernel_size``.

    Taps are sampled at offsets symmetric around zero (half-integers for even
    sizes), and the kernel sums to 1.
    """
    half_extent = (kernel_size - 1) * 0.5
    offsets = torch.linspace(-half_extent, half_extent, steps=kernel_size)
    weights = torch.exp(-0.5 * (offsets / sigma).pow(2))
    return weights / weights.sum()


def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    """Build a separable 2-D Gaussian kernel of shape ``(kernel_size[1], kernel_size[0])``.

    ``kernel_size`` and ``sigma`` are given as (x, y). The kernel is the outer
    product of the y (rows) and x (columns) 1-D kernels.
    """
    along_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    along_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    column = along_y.unsqueeze(1)
    row = along_x.unsqueeze(0)
    return torch.mm(column, row)


def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    """Blur ``img`` with a Gaussian kernel applied independently to each channel.

    Borders use reflect padding of ``kernel_size // 2`` per side, so odd kernel
    sizes keep the input's spatial size. Integer images are convolved in
    float32 and rounded back to their dtype.

    Raises:
        TypeError: if ``img`` is not an image tensor.
    """
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)

    kernel_dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=kernel_dtype, device=img.device)
    # One copy of the kernel per channel, shaped (C, 1, kh, kw) for a grouped conv.
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    batch, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])

    # torch padding order: (left, right, top, bottom)
    pad_x = kernel_size[0] // 2
    pad_y = kernel_size[1] // 2
    batch = torch_pad(batch, [pad_x, pad_x, pad_y, pad_y], mode="reflect")
    batch = conv2d(batch, kernel, groups=batch.shape[-3])

    return _cast_squeeze_out(batch, need_cast, need_squeeze, out_dtype)
761
762
763


def invert(img: Tensor) -> Tensor:
    """Return the negative of ``img``.

    Each value ``v`` becomes ``bound - v``, where ``bound`` is 1 for
    floating-point images and 255 otherwise.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions.
    """
    _assert_image_tensor(img)

    ndim = img.ndim
    if ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {ndim}")

    _assert_channels(img, [1, 3])

    max_value = 1 if img.is_floating_point() else 255
    return torch.tensor(max_value, dtype=img.dtype, device=img.device) - img


def posterize(img: Tensor, bits: int) -> Tensor:
    """Zero out all but the ``bits`` most significant bits of each uint8 value.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions or is not ``torch.uint8``.
    """
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])

    # In two's complement, -(2 ** k) == ~(2 ** k - 1): every bit set except the k lowest.
    # Written as a negation to stay TorchScript-friendly.
    dropped_bits = 8 - bits
    keep_mask = -int(2**dropped_bits)
    return img & keep_mask


def solarize(img: Tensor, threshold: float) -> Tensor:
    """Invert every value of ``img`` that is greater than or equal to ``threshold``.

    Values below ``threshold`` are kept as-is.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions, or if ``threshold``
            exceeds the image's value bound (checked by ``_assert_threshold``).
    """
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])
    _assert_threshold(img, threshold)

    negative = invert(img)
    return torch.where(img >= threshold, negative, img)


def _blurred_degenerate_image(img: Tensor) -> Tensor:
    """Return a smoothed copy of ``img``, used as the blend target by ``adjust_sharpness``.

    Interior pixels are convolved with a normalized 3x3 kernel (centre weight 5,
    neighbours 1). The one-pixel border is copied unchanged from ``img``.
    """
    kernel_dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    weights = torch.ones((3, 3), dtype=kernel_dtype, device=img.device)
    weights[1, 1] = 5.0
    weights /= weights.sum()
    # Depthwise: one kernel per channel.
    weights = weights.expand(img.shape[-3], 1, weights.shape[0], weights.shape[1])

    smoothed, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [weights.dtype])
    # No padding: the 3x3 "valid" convolution shrinks H and W by 2, which
    # matches the interior slice written below.
    smoothed = conv2d(smoothed, weights, groups=smoothed.shape[-3])
    smoothed = _cast_squeeze_out(smoothed, need_cast, need_squeeze, out_dtype)

    out = img.clone()
    out[..., 1:-1, 1:-1] = smoothed
    return out


def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust sharpness by blending ``img`` with a blurred copy of itself.

    ``sharpness_factor`` is passed to ``_blend`` as the blend ratio. Images
    whose height or width is 2 or less are returned unchanged, because the
    3x3 blur leaves no interior pixels to modify.

    Raises:
        ValueError: if ``sharpness_factor`` is negative.
    """
    if sharpness_factor < 0:
        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    width, height = img.size(-1), img.size(-2)
    if width <= 2 or height <= 2:
        return img

    degenerate = _blurred_degenerate_image(img)
    return _blend(img, degenerate, sharpness_factor)


def autocontrast(img: Tensor) -> Tensor:
    """Linearly stretch each channel so its minimum maps to 0 and its maximum to the value bound.

    Min/max are taken over the last two dims per channel. The bound is 1.0
    for floating-point images and 255.0 otherwise. Channels whose stretch
    factor is non-finite (e.g. max == min) are left unscaled, apart from
    the final clamp to ``[0, bound]``.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions.
    """
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    is_float = img.is_floating_point()
    bound = 1.0 if is_float else 255.0
    work_dtype = img.dtype if is_float else torch.float32

    low = img.amin(dim=(-2, -1), keepdim=True).to(work_dtype)
    high = img.amax(dim=(-2, -1), keepdim=True).to(work_dtype)
    scale = bound / (high - low)

    # A constant channel divides by zero; fall back to an identity mapping there.
    non_finite = torch.logical_not(torch.isfinite(scale))
    low[non_finite] = 0
    scale[non_finite] = 1

    stretched = (img - low) * scale
    return stretched.clamp(0, bound).to(img.dtype)


864
def _scale_channel(img_chan: Tensor) -> Tensor:
865
866
867
868
869
870
871
872
    # TODO: we should expect bincount to always be faster than histc, but this
    # isn't always the case. Once
    # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
    # block and only use bincount.
    if img_chan.is_cuda:
        hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
    else:
        hist = torch.bincount(img_chan.view(-1), minlength=256)
873
874

    nonzero_hist = hist[hist != 0]
875
    step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
876
877
878
    if step == 0:
        return img_chan

879
    lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
880
881
882
883
884
885
886
887
888
889
    lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)

    return lut[img_chan.to(torch.int64)].to(torch.uint8)


def _equalize_single_image(img: Tensor) -> Tensor:
    """Equalize each slice along dim 0 (the channels) independently and restack them."""
    equalized_channels = [_scale_channel(channel) for channel in img]
    return torch.stack(equalized_channels)


def equalize(img: Tensor) -> Tensor:
    """Histogram-equalize a uint8 image, or a batch of images, channel by channel.

    Args:
        img: a 3-D image or a 4-D batch (batch on dim 0) of dtype ``torch.uint8``.

    Raises:
        TypeError: if ``img`` is not 3- or 4-dimensional, or is not ``torch.uint8``.
    """
    _assert_image_tensor(img)

    if not (3 <= img.ndim <= 4):
        raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])

    if img.ndim == 4:
        # Batched input: equalize every image on its own.
        return torch.stack([_equalize_single_image(single) for single in img])
    return _equalize_single_image(img)
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940


def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Normalize a float tensor image channel-wise as ``(tensor - mean) / std``.

    Args:
        tensor: floating-point tensor with at least 3 dims, ``(..., C, H, W)``.
        mean: per-channel means.
        std: per-channel standard deviations.
        inplace: if True, modify ``tensor`` itself; otherwise work on a clone.

    Returns:
        The normalized tensor (``tensor`` itself when ``inplace`` is True).

    Raises:
        TypeError: if ``tensor`` is not floating point.
        ValueError: if ``tensor`` has fewer than 3 dims, or if any ``std``
            entry is zero after conversion to ``tensor.dtype``.
    """
    _assert_image_tensor(tensor)

    if not tensor.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")

    if tensor.ndim < 3:
        raise ValueError(
            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
        )

    out = tensor if inplace else tensor.clone()

    dtype = out.dtype
    mean_t = torch.as_tensor(mean, dtype=dtype, device=out.device)
    std_t = torch.as_tensor(std, dtype=dtype, device=out.device)
    # Checked after the cast: a tiny std can round to 0 in a low-precision dtype.
    if (std_t == 0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
    # 1-D statistics become (C, 1, 1) so they broadcast over H and W.
    if mean_t.ndim == 1:
        mean_t = mean_t.view(-1, 1, 1)
    if std_t.ndim == 1:
        std_t = std_t.view(-1, 1, 1)
    return out.sub_(mean_t).div_(std_t)


def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Overwrite an ``h x w`` region of ``img`` with ``v``.

    The region spans rows ``i:i+h`` and columns ``j:j+w`` of the last two dims.

    Args:
        img: image tensor.
        i, j: top row and left column of the region.
        h, w: region height and width.
        v: value(s) to write; must broadcast to the region.
        inplace: if True, modify ``img`` itself; otherwise work on a clone.

    Returns:
        The modified tensor.
    """
    _assert_image_tensor(img)

    target = img if inplace else img.clone()
    target[..., i : i + h, j : j + w] = v
    return target
941
942


943
944
945
946
947
948
def _create_identity_grid(size: List[int]) -> Tensor:
    hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
    grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
    return torch.stack([grid_x, grid_y], -1).unsqueeze(0)  # 1 x H x W x 2


949
950
951
952
953
954
955
956
957
958
959
960
961
def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: str = "bilinear",
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Warp ``img`` by adding ``displacement`` to the identity sampling grid.

    Args:
        img: image tensor; its last two dims are used as the grid size.
        displacement: offsets added to the ``(1, H, W, 2)`` identity grid.
            Presumably it has that same shape; TODO confirm against callers.
        interpolation: interpolation mode forwarded to ``_apply_grid_transform``.
        fill: fill value(s) forwarded to ``_apply_grid_transform``.

    Raises:
        TypeError: if ``img`` is not a tensor.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    spatial_size = list(img.shape[-2:])
    target_device = img.device
    offsets = displacement.to(target_device)

    grid = _create_identity_grid(spatial_size).to(target_device) + offsets
    return _apply_grid_transform(img, grid, interpolation, fill)