# functional_tensor.py
import warnings
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad


def _is_tensor_a_torch_image(x: Tensor) -> bool:
    return x.ndim >= 2


def _assert_image_tensor(img: Tensor) -> None:
    if not _is_tensor_a_torch_image(img):
        raise TypeError("Tensor is not a torch image.")


def _assert_threshold(img: Tensor, threshold: float) -> None:
    bound = 1 if img.is_floating_point() else 255
    if threshold > bound:
        raise TypeError("Threshold should be less than bound of img.")


def get_dimensions(img: Tensor) -> List[int]:
    _assert_image_tensor(img)
    channels = 1 if img.ndim == 2 else img.shape[-3]
    height, width = img.shape[-2:]
    return [channels, height, width]


def get_image_size(img: Tensor) -> List[int]:
    # Returns (w, h) of tensor image
    _assert_image_tensor(img)
    return [img.shape[-1], img.shape[-2]]


def get_image_num_channels(img: Tensor) -> int:
    _assert_image_tensor(img)
    if img.ndim == 2:
        return 1
    elif img.ndim > 2:
        return img.shape[-3]

    raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")


def _max_value(dtype: torch.dtype) -> int:
    if dtype == torch.uint8:
        return 255
    elif dtype == torch.int8:
        return 127
    elif dtype == torch.int16:
        return 32767
    elif dtype == torch.int32:
        return 2147483647
    elif dtype == torch.int64:
        return 9223372036854775807
    else:
        return 1


def _assert_channels(img: Tensor, permitted: List[int]) -> None:
    c = get_dimensions(img)[0]
    if c not in permitted:
        raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    if image.dtype == dtype:
        return image

    if image.is_floating_point():

        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # For data in the range 0-1, (float * 255).to(uint) is only 255
        # when float is exactly 1.0.
        # `max + 1 - epsilon` provides more evenly distributed mapping of
        # ranges of floats to ints.
        eps = 1e-3
        max_val = float(_max_value(dtype))
        result = image.mul(max_val + 1.0 - eps)
        return result.to(dtype)
    else:
        input_max = float(_max_value(image.dtype))

        # int to float
        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            image = image.to(dtype)
            return image / input_max

        output_max = float(_max_value(dtype))

        # int to int
        if input_max > output_max:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image // factor can produce different results
            factor = int((input_max + 1) // (output_max + 1))
            image = torch.div(image, factor, rounding_mode="floor")
            return image.to(dtype)
        else:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image * factor can produce different results
            factor = int((output_max + 1) // (input_max + 1))
            image = image.to(dtype)
            return image * factor
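# Illustrative usage sketch (not part of the original module): converting a uint8
# image to float32 maps [0, 255] into [0.0, 1.0], and converting back rescales by
# `_max_value(torch.uint8)`:
#
#   >>> img_u8 = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
#   >>> img_f = convert_image_dtype(img_u8, torch.float32)   # values in [0.0, 1.0]
#   >>> img_back = convert_image_dtype(img_f, torch.uint8)   # values back in [0, 255]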


def vflip(img: Tensor) -> Tensor:
    _assert_image_tensor(img)

    return img.flip(-2)


def hflip(img: Tensor) -> Tensor:
    _assert_image_tensor(img)

    return img.flip(-1)


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    _assert_image_tensor(img)

    _, h, w = get_dimensions(img)
    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        padding_ltrb = [
            max(-left + min(0, right), 0),
            max(-top + min(0, bottom), 0),
            max(right - max(w, left), 0),
            max(bottom - max(h, top), 0),
        ]
        return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
    return img[..., top:bottom, left:right]


def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])

    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")

    if img.shape[-3] == 3:
        r, g, b = img.unbind(dim=-3)
        # This implementation closely follows the TF one:
        # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
        l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
        l_img = l_img.unsqueeze(dim=-3)
    else:
        l_img = img.clone()

    if num_output_channels == 3:
        return l_img.expand(img.shape)

    return l_img
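# Illustrative sketch (not part of the original source): the grayscale value is the
# weighted sum 0.2989 R + 0.587 G + 0.114 B used above:
#
#   >>> img = torch.rand(3, 4, 4)
#   >>> gray = rgb_to_grayscale(img)       # shape (1, 4, 4)
#   >>> gray3 = rgb_to_grayscale(img, 3)   # same values expanded to 3 channels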


def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    if brightness_factor < 0:
        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])

    return _blend(img, torch.zeros_like(img), brightness_factor)


def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [3, 1])
    c = get_dimensions(img)[0]
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    if c == 3:
        mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
    else:
        mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)

    return _blend(img, mean, contrast_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor image")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])
    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
        return img

    orig_dtype = img.dtype
    if img.dtype == torch.uint8:
        img = img.to(dtype=torch.float32) / 255.0

    img = _rgb2hsv(img)
    h, s, v = img.unbind(dim=-3)
    h = (h + hue_factor) % 1.0
    img = torch.stack((h, s, v), dim=-3)
    img_hue_adj = _hsv2rgb(img)

    if orig_dtype == torch.uint8:
        img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)

    return img_hue_adj
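# Illustrative sketch (not part of the original source): the hue channel is shifted
# in HSV space, so a factor of 0.5 rotates hues by half the color wheel:
#
#   >>> img = torch.rand(3, 4, 4)
#   >>> shifted = adjust_hue(img, 0.25)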


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    if saturation_factor < 0:
        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])

    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
        return img

    return _blend(img, rgb_to_grayscale(img), saturation_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    if not isinstance(img, torch.Tensor):
        raise TypeError("Input img should be a Tensor.")

    _assert_channels(img, [1, 3])

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    result = img
    dtype = img.dtype
    if not torch.is_floating_point(img):
        result = convert_image_dtype(result, torch.float32)

    result = (gain * result**gamma).clamp(0, 1)

    result = convert_image_dtype(result, dtype)
    return result
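# Illustrative sketch (not part of the original source): gamma correction applies
# out = gain * img ** gamma on a float image in [0, 1]:
#
#   >>> img = torch.rand(3, 4, 4)
#   >>> darker = adjust_gamma(img, gamma=2.0)   # gamma > 1 darkens mid-tones
#   >>> lighter = adjust_gamma(img, gamma=0.5)  # gamma < 1 brightens mid-tones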


def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    ratio = float(ratio)
    bound = 1.0 if img1.is_floating_point() else 255.0
    return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)


def _rgb2hsv(img: Tensor) -> Tensor:
    r, g, b = img.unbind(dim=-3)

    # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
    # src/libImaging/Convert.c#L330
    maxc = torch.max(img, dim=-3).values
    minc = torch.min(img, dim=-3).values

    # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
    # from happening in the results, because
    #   + S channel has division by `maxc`, which is zero only if `maxc = minc`
    #   + H channel has division by `(maxc - minc)`.
    #
    # Instead of overwriting NaN afterwards, we just prevent it from occurring so
    # we don't need to deal with it in case we save the NaN in a buffer in
    # backprop, if it is ever supported, but it doesn't hurt to do so.
    eqc = maxc == minc

    cr = maxc - minc
    # Since `eqc => cr = 0`, replacing the denominator with 1 when `eqc` holds is fine.
    ones = torch.ones_like(maxc)
    s = cr / torch.where(eqc, ones, maxc)
    # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
    # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
    # would not matter what values `rc`, `gc`, and `bc` have here, and thus
    # replacing the denominator with 1 when `eqc` holds is fine.
    cr_divisor = torch.where(eqc, ones, cr)
    rc = (maxc - r) / cr_divisor
    gc = (maxc - g) / cr_divisor
    bc = (maxc - b) / cr_divisor

    hr = (maxc == r) * (bc - gc)
    hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
    hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
    h = hr + hg + hb
    h = torch.fmod((h / 6.0 + 1.0), 1.0)
    return torch.stack((h, s, maxc), dim=-3)


def _hsv2rgb(img: Tensor) -> Tensor:
    h, s, v = img.unbind(dim=-3)
    i = torch.floor(h * 6.0)
    f = (h * 6.0) - i
    i = i.to(dtype=torch.int32)

    p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
    q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
    t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
    i = i % 6

    mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)

    a1 = torch.stack((v, q, p, p, t, v), dim=-3)
    a2 = torch.stack((t, v, v, q, p, p), dim=-3)
    a3 = torch.stack((p, p, t, v, v, q), dim=-3)
    a4 = torch.stack((a1, a2, a3), dim=-4)

    return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)


def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
    # padding is left, right, top, bottom

    # crop if needed
    if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
        neg_min_padding = [-min(x, 0) for x in padding]
        crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
        img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
        padding = [max(x, 0) for x in padding]

    in_sizes = img.size()

    _x_indices = [i for i in range(in_sizes[-1])]  # [0, 1, 2, 3, ...]
    left_indices = [i for i in range(padding[0] - 1, -1, -1)]  # e.g. [3, 2, 1, 0]
    right_indices = [-(i + 1) for i in range(padding[1])]  # e.g. [-1, -2, -3]
    x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)

    _y_indices = [i for i in range(in_sizes[-2])]
    top_indices = [i for i in range(padding[2] - 1, -1, -1)]
    bottom_indices = [-(i + 1) for i in range(padding[3])]
    y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)

    ndim = img.ndim
    if ndim == 3:
        return img[:, y_indices[:, None], x_indices[None, :]]
    elif ndim == 4:
        return img[:, :, y_indices[:, None], x_indices[None, :]]
    else:
        raise RuntimeError("Symmetric padding of N-D tensors is not supported yet")


def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This may be unreachable
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif len(padding) == 1:
        pad_left = pad_right = pad_top = pad_bottom = padding[0]
    elif len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    else:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    return [pad_left, pad_right, pad_top, pad_bottom]
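# Illustrative sketch (not part of the original source): the padding argument is
# expanded to the [left, right, top, bottom] form expected by torch_pad:
#
#   >>> _parse_pad_padding([2])           # [2, 2, 2, 2]
#   >>> _parse_pad_padding([1, 3])        # [1, 1, 3, 3]
#   >>> _parse_pad_padding([1, 2, 3, 4])  # [1, 3, 2, 4]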


def pad(
    img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
) -> Tensor:
    _assert_image_tensor(img)

    if fill is None:
        fill = 0

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list):
        # TODO: Jit is failing on loading this op when scripted and saved
        # https://github.com/pytorch/pytorch/issues/81100
        if len(padding) not in [1, 2, 4]:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    p = _parse_pad_padding(padding)

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)

    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
        # Here we temporarily cast the input tensor to float
        # until the pytorch issue is resolved:
        # https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    if padding_mode in ("reflect", "replicate"):
        img = torch_pad(img, p, mode=padding_mode)
    else:
        img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img
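# Illustrative sketch (not part of the original source): constant padding adds a
# border with the fill value, while "symmetric" mirrors the image including the
# edge pixels:
#
#   >>> img = torch.rand(3, 8, 8)
#   >>> padded = pad(img, [2], fill=0)                          # shape (3, 12, 12)
#   >>> mirrored = pad(img, [1, 2], padding_mode="symmetric")   # shape (3, 12, 10)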


def resize(
    img: Tensor,
    size: List[int],
    interpolation: str = "bilinear",
    antialias: Optional[bool] = None,
) -> Tensor:
    _assert_image_tensor(img)

    if isinstance(size, tuple):
        size = list(size)

    if antialias is None:
        antialias = False

    if antialias and interpolation not in ["bilinear", "bicubic"]:
        raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])

    # Define align_corners to avoid warnings
    align_corners = False if interpolation in ["bilinear", "bicubic"] else None

    img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)

    if interpolation == "bicubic" and out_dtype == torch.uint8:
        img = img.clamp(min=0, max=255)

    img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)

    return img
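# Illustrative sketch (not part of the original source): size is (h, w); enabling
# antialias matters mainly when downscaling with bilinear/bicubic modes:
#
#   >>> img = torch.rand(3, 64, 64)
#   >>> small = resize(img, [32, 48], interpolation="bilinear", antialias=True)  # (3, 32, 48)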


def _assert_grid_transform_inputs(
    img: Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: Optional[Union[int, float, List[float]]],
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:

    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor")

    _assert_image_tensor(img)

    if matrix is not None and not isinstance(matrix, list):
        raise TypeError("Argument matrix should be a list")

    if matrix is not None and len(matrix) != 6:
        raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None and not isinstance(fill, (int, float, tuple, list)):
        warnings.warn("Argument fill should be either int, float, tuple or list")

    # Check fill
    num_channels = get_dimensions(img)[0]
    if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
        msg = (
            "The number of elements in 'fill' cannot broadcast to match the number of "
            "channels of the image ({} != {})"
        )
        raise ValueError(msg.format(len(fill), num_channels))

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
    need_squeeze = False
    # make image NCHW
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if out_dtype not in req_dtypes:
        need_cast = True
        req_dtype = req_dtypes[0]
        img = img.to(req_dtype)
    return img, need_cast, need_squeeze, out_dtype


def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
            # it is better to round before cast
            img = torch.round(img)
        img = img.to(out_dtype)

    return img


def _apply_grid_transform(
    img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])

    if img.shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
        img = torch.cat((img, mask), dim=1)

    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        mask = img[:, -1:, :, :]  # N * 1 * H * W
        img = img[:, :-1, :, :]  # N * C * H * W
        mask = mask.expand_as(img)
        fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
        fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
        if mode == "nearest":
            mask = mask < 0.5
            img[mask] = fill_img[mask]
        else:  # 'bilinear'
            img = img * mask + (1.0 - mask) * fill_img

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img


def _gen_affine_grid(
    theta: Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
    x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


def affine(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
    shape = img.shape
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
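# Illustrative sketch (not part of the original source): `matrix` is the flattened
# 2x3 affine matrix in the center-origin convention used by `_gen_affine_grid`;
# the identity matrix leaves the image (approximately) unchanged:
#
#   >>> img = torch.rand(3, 16, 16)
#   >>> out = affine(img, matrix=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], interpolation="nearest")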


def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:

    # Inspired by the PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are Top-Left, Bottom-Left, Bottom-Right, Top-Right points.
    # Points are shifted due to the affine matrix torch convention about
    # the center point. Center is (0, 0) for the image center pivot point (w * 0.5, h * 0.5)
    pts = torch.tensor(
        [
            [-0.5 * w, -0.5 * h, 1.0],
            [-0.5 * w, 0.5 * h, 1.0],
            [0.5 * w, 0.5 * h, 1.0],
            [0.5 * w, -0.5 * h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, _ = new_pts.min(dim=0)
    max_vals, _ = new_pts.max(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    min_vals += torch.tensor((w * 0.5, h * 0.5))
    max_vals += torch.tensor((w * 0.5, h * 0.5))

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    cmax = torch.ceil((max_vals / tol).trunc_() * tol)
    cmin = torch.floor((min_vals / tol).trunc_() * tol)
    size = cmax - cmin
    return int(size[0]), int(size[1])  # w, h


def rotate(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    expand: bool = False,
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
    w, h = img.shape[-1], img.shape[-2]
    ow, oh = _compute_affine_output_size(matrix, w, h) if expand else (w, h)
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)

    return _apply_grid_transform(img, grid, interpolation, fill=fill)


def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394

    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
    output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))

    output_grid = output_grid1 / output_grid2 - 1.0
    return output_grid.view(1, oh, ow, 2)


def perspective(
    img: Tensor,
    perspective_coeffs: List[float],
    interpolation: str = "bilinear",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor.")

    _assert_image_tensor(img)

    _assert_grid_transform_inputs(
        img,
        matrix=None,
        interpolation=interpolation,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    ow, oh = img.shape[-1], img.shape[-2]
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)


def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel1d = pdf / pdf.sum()

    return kernel1d


def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
    return kernel2d


def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
        img,
        [
            kernel.dtype,
        ],
    )

    # padding = (left, right, top, bottom)
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img
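# Illustrative sketch (not part of the original source): kernel_size and sigma are
# given per axis, and the spatial size of the output is unchanged thanks to the
# reflect padding applied before the convolution:
#
#   >>> img = torch.rand(3, 32, 32)
#   >>> blurred = gaussian_blur(img, kernel_size=[5, 5], sigma=[1.0, 1.0])  # (3, 32, 32)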


def invert(img: Tensor) -> Tensor:

    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
    return bound - img


def posterize(img: Tensor, bits: int) -> Tensor:

    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])
    mask = -int(2 ** (8 - bits))  # JIT-friendly for: ~(2 ** (8 - bits) - 1)
    return img & mask
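# Illustrative sketch (not part of the original source): the mask clears the
# (8 - bits) least significant bits of every uint8 pixel:
#
#   >>> img = torch.arange(256, dtype=torch.uint8).reshape(1, 16, 16)
#   >>> out = posterize(img, bits=3)  # keeps only the top 3 bits of each value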


def solarize(img: Tensor, threshold: float) -> Tensor:

    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    _assert_threshold(img, threshold)

    inverted_img = invert(img)
    return torch.where(img >= threshold, inverted_img, img)


def _blurred_degenerate_image(img: Tensor) -> Tensor:
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
    kernel[1, 1] = 5.0
    kernel /= kernel.sum()
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
        img,
        [
            kernel.dtype,
        ],
    )
    result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
    result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)

    result = img.clone()
    result[..., 1:-1, 1:-1] = result_tmp

    return result


def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    if sharpness_factor < 0:
        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")

    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])

    if img.size(-1) <= 2 or img.size(-2) <= 2:
        return img

    return _blend(img, _blurred_degenerate_image(img), sharpness_factor)


def autocontrast(img: Tensor) -> Tensor:

    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    bound = 1.0 if img.is_floating_point() else 255.0
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
    maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
    scale = bound / (maximum - minimum)
    eq_idxs = torch.isfinite(scale).logical_not()
    minimum[eq_idxs] = 0
    scale[eq_idxs] = 1

    return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
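# Illustrative sketch (not part of the original source): autocontrast linearly
# remaps each channel so its minimum and maximum reach the full value range:
#
#   >>> img = torch.rand(3, 16, 16) * 0.5 + 0.25  # values in roughly [0.25, 0.75]
#   >>> stretched = autocontrast(img)             # values stretched toward [0.0, 1.0]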


def _scale_channel(img_chan: Tensor) -> Tensor:
    # TODO: we should expect bincount to always be faster than histc, but this
    # isn't always the case. Once
    # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
    # block and only use bincount.
    if img_chan.is_cuda:
        hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
    else:
        hist = torch.bincount(img_chan.view(-1), minlength=256)

    nonzero_hist = hist[hist != 0]
    step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
    if step == 0:
        return img_chan

    lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
    lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)

    return lut[img_chan.to(torch.int64)].to(torch.uint8)


def _equalize_single_image(img: Tensor) -> Tensor:
    return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])


def equalize(img: Tensor) -> Tensor:

    _assert_image_tensor(img)

    if not (3 <= img.ndim <= 4):
        raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])

    if img.ndim == 3:
        return _equalize_single_image(img)

    return torch.stack([_equalize_single_image(x) for x in img])


def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    _assert_image_tensor(tensor)

    if not tensor.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")

    if tensor.ndim < 3:
        raise ValueError(
            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
        )

    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std == 0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    tensor.sub_(mean).div_(std)
    return tensor
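# Illustrative sketch (not part of the original source): normalize subtracts the
# per-channel mean and divides by the per-channel std of a float image:
#
#   >>> img = torch.rand(3, 8, 8)
#   >>> out = normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])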


def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    _assert_image_tensor(img)

    if not inplace:
        img = img.clone()

    img[..., i : i + h, j : j + w] = v
    return img


def _create_identity_grid(size: List[int]) -> Tensor:
    hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
    grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
    return torch.stack([grid_x, grid_y], -1).unsqueeze(0)  # 1 x H x W x 2


def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: str = "bilinear",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:

    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    size = list(img.shape[-2:])
    displacement = displacement.to(img.device)

    identity_grid = _create_identity_grid(size)
    grid = identity_grid.to(img.device) + displacement
    return _apply_grid_transform(img, grid, interpolation, fill)