functional_tensor.py 23.9 KB
Newer Older
vfdev's avatar
vfdev committed
1
2
3
import warnings
from typing import Optional

4
import torch
5
from torch import Tensor
vfdev's avatar
vfdev committed
6
from torch.nn.functional import affine_grid, grid_sample
7
from torch.jit.annotations import List, BroadcastingList2
8
9


vfdev's avatar
vfdev committed
10
11
def _is_tensor_a_torch_image(x: Tensor) -> bool:
    return x.ndim >= 2
12
13


vfdev's avatar
vfdev committed
14
def _get_image_size(img: Tensor) -> List[int]:
    """Return the spatial size of an image tensor as ``[width, height]``.

    The size is read from the last two dimensions, so leading channel/batch
    dimensions are ignored.

    Raises:
        TypeError: if ``img`` is rejected by ``_is_tensor_a_torch_image``.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError("Unexpected type {}".format(type(img)))
    height, width = img.shape[-2], img.shape[-1]
    return [width, height]


def vflip(img: Tensor) -> Tensor:
    """Flip an Image Tensor upside down.

    Args:
        img (Tensor): Image Tensor in the form [..., H, W]; its rows (dim -2) are reversed.

    Returns:
        Tensor: Vertically flipped image Tensor.

    Raises:
        TypeError: if ``img`` is not a torch image.
    """
    if _is_tensor_a_torch_image(img):
        return torch.flip(img, [-2])
    raise TypeError('tensor is not a torch image.')


def hflip(img: Tensor) -> Tensor:
    """Mirror an Image Tensor left to right.

    Args:
        img (Tensor): Image Tensor in the form [..., H, W]; its columns (dim -1) are reversed.

    Returns:
        Tensor: Horizontally flipped image Tensor.

    Raises:
        TypeError: if ``img`` is not a torch image.
    """
    if _is_tensor_a_torch_image(img):
        return torch.flip(img, [-1])
    raise TypeError('tensor is not a torch image.')


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Cut a rectangular window out of an Image Tensor.

    Args:
        img (Tensor): Image in the form [..., H, W]; (0, 0) is the top left corner.
        top (int): Row index of the window's top edge.
        left (int): Column index of the window's left edge.
        height (int): Number of rows in the window.
        width (int): Number of columns in the window.

    Returns:
        Tensor: ``img[..., top:top + height, left:left + width]``. Ordinary slicing
        rules apply, so a window reaching past the border is truncated.

    Raises:
        TypeError: if ``img`` is not a torch image.
    """
    if _is_tensor_a_torch_image(img):
        bottom = top + height
        right = left + width
        return img[..., top:bottom, left:right]
    raise TypeError("tensor is not a torch image.")


def rgb_to_grayscale(img: Tensor) -> Tensor:
    """Convert an RGB Image Tensor to a single-channel grayscale image.

    Uses the ITU-R 601-2 luma transform ``L = R * 0.2989 + G * 0.5870 + B * 0.1140``.

    Args:
        img (Tensor): Image in the form [C, H, W] with exactly 3 channels (R, G, B) on dim 0.

    Returns:
        Tensor: Grayscale image of shape [H, W] with the dtype of ``img``. For integer
        inputs the weighted sum is truncated (not rounded) by the final cast.

    Raises:
        TypeError: if the first dimension of ``img`` is not 3.
    """
    if img.shape[0] != 3:
        raise TypeError('Input Image does not contain 3 Channels')

    red, green, blue = img[0], img[1], img[2]
    luma = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    return luma.to(img.dtype)


def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Scale the brightness of an RGB image.

    Args:
        img (Tensor): Image to be adjusted.
        brightness_factor (float): Non-negative multiplier. 0 yields a black image,
            1 the original image, and 2 an image twice as bright (values are clamped
            to the valid range by the blend).

    Returns:
        Tensor: Brightness adjusted image with the same dtype as ``img``.

    Raises:
        ValueError: if ``brightness_factor`` is negative.
        TypeError: if ``img`` is not a torch image.
    """
    if brightness_factor < 0:
        raise ValueError('brightness_factor ({}) is not non-negative.'.format(brightness_factor))

    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    # Blending towards an all-zero image is equivalent to multiplying by the factor.
    black = torch.zeros_like(img)
    return _blend(img, black, brightness_factor)


def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Change the contrast of an RGB image.

    The image is blended with the mean intensity of its grayscale version.

    Args:
        img (Tensor): Image to be adjusted (three channels on dim 0, as required by
            ``rgb_to_grayscale``).
        contrast_factor (float): Non-negative blending factor. 0 gives a solid gray
            image, 1 the original image, and 2 doubles the contrast.

    Returns:
        Tensor: Contrast adjusted image with the same dtype as ``img``.

    Raises:
        ValueError: if ``contrast_factor`` is negative.
        TypeError: if ``img`` is not a torch image.
    """
    if contrast_factor < 0:
        raise ValueError('contrast_factor ({}) is not non-negative.'.format(contrast_factor))

    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    # The mean is computed in float so integer images keep fractional precision; the
    # result is a 0-dim tensor that broadcasts against ``img`` inside the blend.
    gray = rgb_to_grayscale(img).to(torch.float)
    mean = gray.mean()
    return _blend(img, mean, contrast_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (Tensor): Image to be adjusted, with 3 channels on dim 0. Image type is
            either uint8 or float (float values are expected in [0, 1]).
        hue_factor (float):  How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
         Tensor: Hue adjusted image with the same dtype as ``img``.

    Raises:
        ValueError: if ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: if ``img`` is not a torch image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    orig_dtype = img.dtype
    if img.dtype == torch.uint8:
        # The HSV helpers work on [0, 1] floats.
        img = img.to(dtype=torch.float32) / 255.0

    img = _rgb2hsv(img)
    h, s, v = img.unbind(0)
    # Shift the hue and wrap it back into [0, 1). Build a new tensor rather than
    # mutating the view returned by ``unbind`` in place.
    h = (h + hue_factor) % 1.0
    img = torch.stack((h, s, v))
    img_hue_adj = _hsv2rgb(img)

    if orig_dtype == torch.uint8:
        img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)

    return img_hue_adj


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Change the color saturation of an RGB image.

    The image is blended with its own grayscale version.

    Args:
        img (Tensor): Image to be adjusted (three channels on dim 0, as required by
            ``rgb_to_grayscale``).
        saturation_factor (float): Non-negative blending factor. 0 gives a black and
            white image, 1 the original image, and 2 doubles the saturation.

    Returns:
        Tensor: Saturation adjusted image with the same dtype as ``img``.

    Raises:
        ValueError: if ``saturation_factor`` is negative.
        TypeError: if ``img`` is not a torch image.
    """
    if saturation_factor < 0:
        raise ValueError('saturation_factor ({}) is not non-negative.'.format(saturation_factor))

    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    # The [H, W] grayscale image broadcasts over the channel dimension of ``img``.
    grayscale = rgb_to_grayscale(img)
    return _blend(img, grayscale, saturation_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Adjust gamma of an RGB image.

    Also known as Power Law Transform. For integer images the intensities are
    adjusted according to

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    Floating point images are used as-is (without the 255 scaling). In both cases the
    normalized result is clamped to [0, 1] before converting back.

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (Tensor): Tensor of RGB values to be adjusted.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 make the shadows darker,
            while gamma smaller than 1 make dark regions lighter.
        gain (float): The constant multiplier.

    Returns:
        Tensor: Gamma adjusted image with the same dtype as ``img``.

    Raises:
        TypeError: if ``img`` is not a Tensor.
        ValueError: if ``gamma`` is negative.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError('img should be a Tensor. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    original_dtype = img.dtype
    normalized = img if torch.is_floating_point(img) else img / 255.0
    adjusted = (gain * normalized ** gamma).clamp(0, 1)

    if adjusted.dtype != original_dtype:
        # Scale by just under 256 so that 1.0 truncates to 255 when cast back.
        adjusted = (255 + 1.0 - 1e-3) * adjusted
    return adjusted.to(original_dtype)


def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
    """Crop the central region of an Image Tensor.

    Args:
        img (Tensor): Image to be cropped, in the form [..., H, W].
        output_size (sequence or int): (height, width) of the crop box. If int (or a
            sequence of length 1), it is used for both directions.

    Returns:
        Tensor: Cropped image.

    Raises:
        TypeError: if ``img`` is not a torch image.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    if isinstance(output_size, int):
        output_size = [output_size, output_size]
    elif len(output_size) == 1:
        output_size = [output_size[0], output_size[0]]

    # Read the spatial size from the last two dims. Unpacking ``img.size()`` as
    # (C, W, H) swapped height and width and only worked for 3-D tensors.
    image_width, image_height = _get_image_size(img)
    crop_height, crop_width = output_size

    # crop_top = int(round((image_height - crop_height) / 2.))
    # Result can be different between python func and scripted func
    # Temporary workaround:
    crop_top = int((image_height - crop_height + 1) * 0.5)
    # crop_left = int(round((image_width - crop_width) / 2.))
    # Result can be different between python func and scripted func
    # Temporary workaround:
    crop_left = int((image_width - crop_width + 1) * 0.5)

    return crop(img, crop_top, crop_left, crop_height, crop_width)


def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
    """Crop the given Image Tensor into four corners and the central crop.

    .. Note::
        This transform returns a List of Tensors and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (Tensor): Image to be cropped, in the form [..., H, W].
        size (sequence or int): Desired output size of the crop. If size is an
            int (or a sequence of length 1) instead of a sequence like (h, w),
            a square crop (size, size) is made.

    Returns:
       List: List (tl, tr, bl, br, center)
                Corresponding top left, top right, bottom left, bottom right and center crop.

    Raises:
        TypeError: if ``img`` is not a torch image.
        ValueError: if the requested crop is bigger than the image.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    if isinstance(size, int):
        size = [size, size]
    elif len(size) == 1:
        size = [size[0], size[0]]
    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    # Spatial size comes from the last two dims (returned as [w, h]).
    image_width, image_height = _get_image_size(img)
    crop_height, crop_width = size
    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    # crop() takes (top, left, height, width), not a (left, top, right, bottom) box.
    last_top = image_height - crop_height
    last_left = image_width - crop_width
    tl = crop(img, 0, 0, crop_height, crop_width)
    tr = crop(img, 0, last_left, crop_height, crop_width)
    bl = crop(img, last_top, 0, crop_height, crop_width)
    br = crop(img, last_top, last_left, crop_height, crop_width)
    center = center_crop(img, [crop_height, crop_width])

    return [tl, tr, bl, br, center]


def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]:
    """Crop the given Image Tensor into four corners and the central crop plus the
        flipped version of these (horizontal flipping is used by default).

    .. Note::
        This transform returns a List of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int (or a sequence of length 1) instead of a sequence like (h, w),
            a square crop (size, size) is made.
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Returns:
       List: List (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
                Corresponding top left, top right, bottom left, bottom right and center crop
                and same for the flipped image's tensor.

    Raises:
        TypeError: if ``img`` is not a torch image.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    # Normalize the documented int / single-element forms before the length check;
    # calling len() on a plain int would otherwise raise.
    if isinstance(size, int):
        size = [size, size]
    elif len(size) == 1:
        size = [size[0], size[0]]
    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)

    return first_five + second_five


def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
343
    bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
344
    return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
345
346
347
348
349


def _rgb2hsv(img):
    r, g, b = img.unbind(0)

350
351
352
353
354
355
356
357
358
359
360
361
    maxc = torch.max(img, dim=0).values
    minc = torch.min(img, dim=0).values

    # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
    # from happening in the results, because
    #   + S channel has division by `maxc`, which is zero only if `maxc = minc`
    #   + H channel has division by `(maxc - minc)`.
    #
    # Instead of overwriting NaN afterwards, we just prevent it from occuring so
    # we don't need to deal with it in case we save the NaN in a buffer in
    # backprop, if it is ever supported, but it doesn't hurt to do so.
    eqc = maxc == minc
362
363

    cr = maxc - minc
364
365
366
367
368
369
370
371
372
373
    # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
    s = cr / torch.where(eqc, maxc.new_ones(()), maxc)
    # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
    # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
    # would not matter what values `rc`, `gc`, and `bc` have here, and thus
    # replacing denominator with 1 when `eqc` is fine.
    cr_divisor = torch.where(eqc, maxc.new_ones(()), cr)
    rc = (maxc - r) / cr_divisor
    gc = (maxc - g) / cr_divisor
    bc = (maxc - b) / cr_divisor
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401

    hr = (maxc == r) * (bc - gc)
    hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
    hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
    h = (hr + hg + hb)
    h = torch.fmod((h / 6.0 + 1.0), 1.0)
    return torch.stack((h, s, maxc))


def _hsv2rgb(img):
    h, s, v = img.unbind(0)
    i = torch.floor(h * 6.0)
    f = (h * 6.0) - i
    i = i.to(dtype=torch.int32)

    p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
    q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
    t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
    i = i % 6

    mask = i == torch.arange(6)[:, None, None]

    a1 = torch.stack((v, q, p, p, t, v))
    a2 = torch.stack((t, v, v, q, p, p))
    a3 = torch.stack((p, p, t, v, v, q))
    a4 = torch.stack((a1, a2, a3))

    return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
402
403


404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
    # padding is left, right, top, bottom
    in_sizes = img.size()

    x_indices = [i for i in range(in_sizes[-1])]  # [0, 1, 2, 3, ...]
    left_indices = [i for i in range(padding[0] - 1, -1, -1)]  # e.g. [3, 2, 1, 0]
    right_indices = [-(i + 1) for i in range(padding[1])]  # e.g. [-1, -2, -3]
    x_indices = torch.tensor(left_indices + x_indices + right_indices)

    y_indices = [i for i in range(in_sizes[-2])]
    top_indices = [i for i in range(padding[2] - 1, -1, -1)]
    bottom_indices = [-(i + 1) for i in range(padding[3])]
    y_indices = torch.tensor(top_indices + y_indices + bottom_indices)

    ndim = img.ndim
    if ndim == 3:
        return img[:, y_indices[:, None], x_indices[None, :]]
    elif ndim == 4:
        return img[:, :, y_indices[:, None], x_indices[None, :]]
    else:
        raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")


427
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad the given Tensor Image on all sides with specified padding mode and fill value.

    Args:
        img (Tensor): Image to be padded. Only the last two (H, W) dimensions are padded.
        padding (int or tuple or list): Padding on each border. If a single int is provided this
            is used to pad all borders. If a tuple or list of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple or list of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively. In torchscript mode padding as single int is not supported, use a tuple or
            list of length 1: ``[padding, ]``.
        fill (int): Pixel fill value for constant fill. Default is 0.
            This value is only used when the padding_mode is constant.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant. Symmetric mode requires non-negative padding values.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last value on the edge)

                       padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                       will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value on the edge)

                         padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                         will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        Tensor: Padded image.

    Raises:
        TypeError: if ``img`` is not a torch image or an argument has an unexpected type.
        ValueError: if ``padding`` has an unsupported length, ``padding_mode`` is unknown,
            a negative padding is requested in symmetric mode, or an int padding is
            used while scripting.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This may be unreachable
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif len(padding) == 1:
        pad_left = pad_right = pad_top = pad_bottom = padding[0]
    elif len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    else:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    # torch.nn.functional.pad expects (left, right, top, bottom) for the last two dims.
    p = [pad_left, pad_right, pad_top, pad_bottom]

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        if p[0] < 0 or p[1] < 0 or p[2] < 0 or p[3] < 0:  # no any support for torch script
            raise ValueError("Padding can not be negative for symmetric padding_mode")
        return _pad_symmetric(img, p)

    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
        # Here we temporary cast input tensor to float
        # until pytorch issue is resolved :
        # https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    img = torch.nn.functional.pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img


def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
    r"""Resize the input Tensor to the given size.

    Args:
        img (Tensor): Image to be resized.
        size (int or tuple or list): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int (or a
            sequence of length 1), the smaller edge of the image will be matched to this
            number maintaining the aspect ratio. i.e, if height > width, then image will be
            rescaled to :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
            In torchscript mode size as a single int is not supported, use a tuple or
            list of length 1: ``[size, ]``.
        interpolation (int, optional): Desired interpolation. Default is bilinear (=2). Other supported values:
            nearest(=0) and bicubic(=3).

    Returns:
        Tensor: Resized image with the dtype of ``img``. In the int / length-1 case, if the
        smaller edge already equals ``size``, ``img`` itself is returned unchanged.

    Raises:
        TypeError: if ``img`` is not a torch image or an argument has an unexpected type.
        ValueError: if ``interpolation`` is unsupported or ``size`` has an invalid length.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")

    if not isinstance(size, (int, tuple, list)):
        raise TypeError("Got inappropriate size arg")
    if not isinstance(interpolation, int):
        raise TypeError("Got inappropriate interpolation arg")

    _interpolation_modes = {
        0: "nearest",
        2: "bilinear",
        3: "bicubic",
    }

    if interpolation not in _interpolation_modes:
        raise ValueError("This interpolation mode is unsupported with Tensor input")

    if isinstance(size, tuple):
        size = list(size)

    if isinstance(size, list) and len(size) not in [1, 2]:
        raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
                         "{} element tuple/list".format(len(size)))

    w, h = _get_image_size(img)

    if isinstance(size, int):
        size_w, size_h = size, size
    elif len(size) < 2:
        size_w, size_h = size[0], size[0]
    else:
        size_w, size_h = size[1], size[0]  # Convention (h, w)

    if isinstance(size, int) or len(size) < 2:
        # Match the smaller edge to ``size`` and scale the other edge to keep the aspect ratio.
        if w < h:
            size_h = int(size_w * h / w)
        else:
            size_w = int(size_h * w / h)

        if (w <= h and w == size_w) or (h <= w and h == size_h):
            return img

    # make image NCHW
    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    mode = _interpolation_modes[interpolation]

    # interpolate only works on floating point tensors.
    out_dtype = img.dtype
    need_cast = False
    if img.dtype not in (torch.float32, torch.float64):
        need_cast = True
        img = img.to(torch.float32)

    # Define align_corners to avoid warnings
    align_corners = False if mode in ["bilinear", "bicubic"] else None

    img = torch.nn.functional.interpolate(img, size=(size_h, size_w), mode=mode, align_corners=align_corners)

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        # Bicubic can overshoot the input range; clamp before casting back
        # (the 255 bound assumes 8-bit image values).
        if mode == "bicubic":
            img = img.clamp(min=0, max=255)
        img = img.to(out_dtype)

    return img


def affine(
        img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None
) -> Tensor:
    """Apply affine transformation on the Tensor image keeping image center invariant.

    Args:
        img (Tensor): image to be transformed, in the form [C, H, W] or [N, C, H, W].
        matrix (list of floats): list of 6 float values representing the inverse matrix for the
            affine transformation. It is reshaped to ``(2, 3)`` and passed as ``theta`` to
            ``affine_grid``, so it acts on normalized [-1, 1] coordinates.
        resample (int, optional): An optional resampling filter. Default is nearest (=0). Other
            supported values: bilinear (=2).
        fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the
            area outside the transform in the output image is always 0.

    Returns:
        Tensor: Transformed image with the same dtype and device as ``img``.

    Raises:
        TypeError: if ``img`` is not a Tensor image.
        ValueError: if ``resample`` is not a supported mode.
    """
    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))

    if fillcolor is not None:
        warnings.warn("Argument fillcolor is not supported for Tensor input. Fill value is zero")

    _interpolation_modes = {
        0: "nearest",
        2: "bilinear",
    }

    if resample not in _interpolation_modes:
        raise ValueError("This resampling mode is unsupported with Tensor input")

    # make image NCHW
    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    mode = _interpolation_modes[resample]

    out_dtype = img.dtype
    need_cast = False
    if img.dtype not in (torch.float32, torch.float64):
        need_cast = True
        img = img.to(torch.float32)

    # grid_sample requires the grid to have the same dtype and device as the input and
    # one grid per batch element, so theta is built from the (possibly cast) image and
    # broadcast over the batch dimension.
    theta = torch.tensor(matrix, dtype=img.dtype, device=img.device).reshape(1, 2, 3)
    shape = img.shape
    grid = affine_grid(
        theta.expand(shape[0], 2, 3), size=[shape[0], shape[1], shape[2], shape[3]], align_corners=False
    )

    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        # it is better to round before cast
        img = torch.round(img).to(out_dtype)

    return img