import math
import numbers
import warnings
from enum import Enum
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from torch import Tensor

try:
    import accimage
except ImportError:
    accimage = None

from ..utils import _log_api_usage_once
from . import _functional_pil as F_pil, _functional_tensor as F_t


class InterpolationMode(Enum):
    """Interpolation modes
    Available interpolation methods are ``nearest``, ``nearest-exact``, ``bilinear``, ``bicubic``, ``box``, ``hamming``,
    and ``lanczos``.
    """

    NEAREST = "nearest"
    NEAREST_EXACT = "nearest-exact"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    # For PIL compatibility
    BOX = "box"
    HAMMING = "hamming"
    LANCZOS = "lanczos"


# TODO: Once torchscript supports Enums with staticmethod
# this can be put into InterpolationMode as staticmethod
def _interpolation_modes_from_int(i: int) -> InterpolationMode:
    inverse_modes_mapping = {
        0: InterpolationMode.NEAREST,
        2: InterpolationMode.BILINEAR,
        3: InterpolationMode.BICUBIC,
        4: InterpolationMode.BOX,
        5: InterpolationMode.HAMMING,
        1: InterpolationMode.LANCZOS,
    }
    return inverse_modes_mapping[i]


pil_modes_mapping = {
    InterpolationMode.NEAREST: 0,
    InterpolationMode.BILINEAR: 2,
    InterpolationMode.BICUBIC: 3,
    InterpolationMode.NEAREST_EXACT: 0,
    InterpolationMode.BOX: 4,
    InterpolationMode.HAMMING: 5,
    InterpolationMode.LANCZOS: 1,
}

_is_pil_image = F_pil._is_pil_image


def get_dimensions(img: Tensor) -> List[int]:
    """Returns the dimensions of an image as [channels, height, width].

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image dimensions.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_dimensions)
    if isinstance(img, torch.Tensor):
        return F_t.get_dimensions(img)

    return F_pil.get_dimensions(img)


def get_image_size(img: Tensor) -> List[int]:
    """Returns the size of an image as [width, height].

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image size.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_image_size)
    if isinstance(img, torch.Tensor):
        return F_t.get_image_size(img)

    return F_pil.get_image_size(img)


def get_image_num_channels(img: Tensor) -> int:
    """Returns the number of channels of an image.

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        int: The number of channels.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_image_num_channels)
    if isinstance(img, torch.Tensor):
        return F_t.get_image_num_channels(img)

    return F_pil.get_image_num_channels(img)


@torch.jit.unused
def _is_numpy(img: Any) -> bool:
    return isinstance(img, np.ndarray)


@torch.jit.unused
def _is_numpy_image(img: Any) -> bool:
    return img.ndim in {2, 3}


def to_tensor(pic) -> Tensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.ToTensor` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_tensor)
    if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

    default_float_dtype = torch.get_default_dtype()

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]

        img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.to(dtype=default_float_dtype).div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic).to(dtype=default_float_dtype)

    # handle PIL Image
    mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}
    img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))

    if pic.mode == "1":
        img = 255 * img
    img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.to(dtype=default_float_dtype).div(255)
    else:
        return img
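# Usage sketch (illustrative comment, not part of the original module): converting a
# PIL image to a float tensor in [0.0, 1.0]. The file name below is hypothetical.
#
#   from PIL import Image
#   pil_img = Image.open("example.jpg").convert("RGB")   # hypothetical file
#   t = to_tensor(pil_img)        # CHW float tensor, values scaled to [0.0, 1.0]
#   raw = pil_to_tensor(pil_img)  # CHW uint8 tensor, values kept as 0..255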


def pil_to_tensor(pic: Any) -> Tensor:
    """Convert a ``PIL Image`` to a tensor of the same type.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.PILToTensor` for more details.

    .. note::

        A deep copy of the underlying array is performed.

    Args:
        pic (PIL Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pil_to_tensor)
    if not F_pil._is_pil_image(pic):
        raise TypeError(f"pic should be PIL Image. Got {type(pic)}")

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage format is always uint8 internally, so always return uint8 here
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
        pic.copyto(nppic)
        return torch.as_tensor(nppic)

    # handle PIL Image
    img = torch.as_tensor(np.array(pic, copy=True))
    img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1))
    return img


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly
    This function does not support PIL Image.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        Tensor: Converted image

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(convert_image_dtype)
    if not isinstance(image, torch.Tensor):
        raise TypeError("Input img should be Tensor Image")

    return F_t.convert_image_dtype(image, dtype)
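# Usage sketch (illustrative comment, not from the original file): float <-> uint8
# round trips rescale values, e.g. 1.0 maps to 255 and 255 maps back to 1.0.
#
#   f = torch.rand(3, 4, 4)                      # float32 in [0, 1]
#   u = convert_image_dtype(f, torch.uint8)      # uint8 in [0, 255]
#   f2 = convert_image_dtype(u, torch.float32)   # back to float, up to rounding error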


def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_pil_image)

    if isinstance(pic, torch.Tensor):
        if pic.ndim == 3:
            pic = pic.permute((1, 2, 0))
        pic = pic.numpy(force=True)
    elif not isinstance(pic, np.ndarray):
        raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")

    if pic.ndim == 2:
        # if 2D image, add channel dimension (HWC)
        pic = np.expand_dims(pic, 2)
    if pic.ndim != 3:
        raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

    if pic.shape[-1] > 4:
        raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")

    npimg = pic

    if np.issubdtype(npimg.dtype, np.floating) and mode != "F":
        npimg = (npimg * 255).astype(np.uint8)

    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = "L"
        elif npimg.dtype == np.int16:
            expected_mode = "I;16"
        elif npimg.dtype == np.int32:
            expected_mode = "I"
        elif npimg.dtype == np.float32:
            expected_mode = "F"
        if mode is not None and mode != expected_mode:
            raise ValueError(f"Incorrect mode ({mode}) supplied for input type {np.dtype}. Should be {expected_mode}")
        mode = expected_mode

    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ["LA"]
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "LA"

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "RGBA"
    else:
        permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs")
        if mode is None and npimg.dtype == np.uint8:
            mode = "RGB"

    if mode is None:
        raise TypeError(f"Input type {npimg.dtype} is not supported")

    return Image.fromarray(npimg, mode=mode)
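# Usage sketch (illustrative comment, not from the original file): a float CHW tensor
# in [0, 1] becomes an 8-bit RGB PIL image; the mode is inferred from shape and dtype.
#
#   t = torch.rand(3, 64, 64)
#   im = to_pil_image(t)   # PIL.Image.Image, mode "RGB", size (64, 64)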


def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Normalize a float tensor image with mean and standard deviation.
    This transform does not support PIL Image.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(normalize)
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")

    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
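# Usage sketch (illustrative comment, not from the original file): per-channel
# standardization with the commonly used ImageNet statistics.
#
#   t = torch.rand(3, 224, 224)
#   out = normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])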


def _compute_resized_output_size(
    image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
    if len(size) == 1:  # specified size only for the smallest edge
        h, w = image_size
        short, long = (w, h) if w <= h else (h, w)
        requested_new_short = size if isinstance(size, int) else size[0]

        new_short, new_long = requested_new_short, int(requested_new_short * long / short)

        if max_size is not None:
            if max_size <= requested_new_short:
                raise ValueError(
                    f"max_size = {max_size} must be strictly greater than the requested "
                    f"size for the smaller edge size = {size}"
                )
            if new_long > max_size:
                new_short, new_long = int(max_size * new_short / new_long), max_size

        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
    else:  # specified both h and w
        new_w, new_h = size[1], size[0]
    return [new_h, new_w]


def resize(
    img: Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> Tensor:
    r"""Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        img (PIL Image or Tensor): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e., if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
            supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image. If the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``,
            ``size`` will be overruled so that the longer edge is equal to
            ``max_size``.
            As a result, the smaller edge may be shorter than ``size``. This
            is only supported if ``size`` is an int (or a sequence of length
            1 in torchscript mode).
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
              Other modes aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support disabling it.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The default value changed from ``None`` to ``True`` in
            v0.17, for the PIL and Tensor backends to be consistent.

    Returns:
        PIL Image or Tensor: Resized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resize)

    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if isinstance(size, (list, tuple)):
        if len(size) not in [1, 2]:
            raise ValueError(
                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
            )
        if max_size is not None and len(size) != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )

    _, image_height, image_width = get_dimensions(img)
    if isinstance(size, int):
        size = [size]
    output_size = _compute_resized_output_size((image_height, image_width), size, max_size)

    if [image_height, image_width] == output_size:
        return img

    if not isinstance(img, torch.Tensor):
        if antialias is False:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)

    return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)
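# Usage sketch (illustrative comment, not from the original file): matching the smaller
# edge to 256 while keeping the aspect ratio, with the longer edge capped at 500.
#
#   t = torch.rand(3, 480, 640)
#   out = resize(t, [256], max_size=500)   # out.shape == torch.Size([3, 256, 341])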


def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Args:
        img (PIL Image or Tensor): Image to be padded.
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        PIL Image or Tensor: Padded image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pad)
    if not isinstance(img, torch.Tensor):
        return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)

    return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
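# Usage sketch (illustrative comment, not from the original file): constant padding
# specified as [left, top, right, bottom].
#
#   t = torch.rand(3, 32, 32)
#   out = pad(t, [4, 2, 4, 2], fill=0)   # out.shape == torch.Size([3, 36, 40])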


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop the given image at specified location and output size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        PIL Image or Tensor: Cropped image.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(crop)
    if not isinstance(img, torch.Tensor):
        return F_pil.crop(img, top, left, height, width)

    return F_t.crop(img, top, left, height, width)


def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
            it is used for both directions.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(center_crop)
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    _, image_height, image_width = get_dimensions(img)
    crop_height, crop_width = output_size

    if crop_width > image_width or crop_height > image_height:
        padding_ltrb = [
            (crop_width - image_width) // 2 if crop_width > image_width else 0,
            (crop_height - image_height) // 2 if crop_height > image_height else 0,
            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
        ]
        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
        _, image_height, image_width = get_dimensions(img)
        if crop_width == image_width and crop_height == image_height:
            return img

    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop(img, crop_top, crop_left, crop_height, crop_width)
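# Usage sketch (illustrative comment, not from the original file): center-cropping a
# rectangular image down to a square; the crop box is centered on the input.
#
#   t = torch.rand(3, 300, 400)
#   out = center_crop(t, [224, 224])   # out.shape == torch.Size([3, 224, 224])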


def resized_crop(
    img: Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> Tensor:
    """Crop the given image and resize it to desired size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
            supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
              Other modes aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support disabling it.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The default value changed from ``None`` to ``True`` in
            v0.17, for the PIL and Tensor backends to be consistent.
    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resized_crop)
    img = crop(img, top, left, height, width)
    img = resize(img, size, interpolation, antialias=antialias)
    return img


def hflip(img: Tensor) -> Tensor:
    """Horizontally flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor:  Horizontally flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(hflip)
    if not isinstance(img, torch.Tensor):
        return F_pil.hflip(img)

    return F_t.hflip(img)


def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
    """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.

    In Perspective Transform each pixel (x, y) in the original image gets transformed as,
     (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )

    Args:
        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.

    Returns:
        octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
    """
    a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)

    for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
        a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
        a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

    b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
    res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution

    output: List[float] = res.tolist()
    return output


def perspective(
    img: Tensor,
    startpoints: List[List[int]],
    endpoints: List[List[int]],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Perform perspective transform of the given image.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be transformed.
        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.

    Returns:
        PIL Image or Tensor: transformed Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(perspective)

    coeffs = _get_perspective_coeffs(startpoints, endpoints)

    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if not isinstance(img, torch.Tensor):
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.perspective(img, coeffs, interpolation=pil_interpolation, fill=fill)

    return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill)
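# Usage sketch (illustrative comment, not from the original file): warping the four
# image corners inward by a few pixels, which simulates a mild viewpoint change.
# Corner points follow the [top-left, top-right, bottom-right, bottom-left] order.
#
#   t = torch.rand(3, 100, 100)
#   start = [[0, 0], [99, 0], [99, 99], [0, 99]]
#   end = [[5, 5], [94, 3], [96, 95], [2, 94]]
#   out = perspective(t, start, end, fill=[0.0])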


def vflip(img: Tensor) -> Tensor:
    """Vertically flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor:  Vertically flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(vflip)
    if not isinstance(img, torch.Tensor):
        return F_pil.vflip(img)

    return F_t.vflip(img)


def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Crop the given image into four corners and the central crop.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Returns:
       tuple: tuple (tl, tr, bl, br, center)
       Corresponding top left, top right, bottom left, bottom right and center crop.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(five_crop)
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    _, image_height, image_width = get_dimensions(img)
    crop_height, crop_width = size
    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    tl = crop(img, 0, 0, crop_height, crop_width)
    tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
    br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)

    center = center_crop(img, [crop_height, crop_width])

    return tl, tr, bl, br, center


def ten_crop(
    img: Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Generate ten cropped images from the given image.
    Crop the given image into four corners and the central crop plus the
    flipped version of these (horizontal flipping is used by default).
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
        Corresponding top left, top right, bottom left, bottom right and
        center crop and same for the flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(ten_crop)
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)
    return first_five + second_five
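# Usage sketch (illustrative comment, not from the original file): ten_crop() returns
# the five_crop() crops plus the crops of the flipped image, ten tensors in total.
#
#   t = torch.rand(3, 256, 256)
#   crops = ten_crop(t, [224, 224])
#   # len(crops) == 10, each crop has shape torch.Size([3, 224, 224])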


def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Adjust brightness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        brightness_factor (float):  How much to adjust the brightness. Can be
            any non-negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image or Tensor: Brightness adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_brightness)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_brightness(img, brightness_factor)

    return F_t.adjust_brightness(img, brightness_factor)


def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust contrast of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non-negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image or Tensor: Contrast adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_contrast)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_contrast(img, contrast_factor)

    return F_t.adjust_contrast(img, contrast_factor)


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Adjust color saturation of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        saturation_factor (float):  How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image or Tensor: Saturation adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_saturation)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_saturation(img, saturation_factor)

    return F_t.adjust_saturation(img, saturation_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
            Note: the pixel values of the input image have to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
        hue_factor (float):  How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image or Tensor: Hue adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_hue)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_hue(img, hue_factor)

    return F_t.adjust_hue(img, hue_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (PIL Image or Tensor): PIL Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, modes with transparency (alpha channel) are not supported.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 makes the shadows darker,
            while gamma smaller than 1 makes dark regions lighter.
        gain (float): The constant multiplier.
    Returns:
        PIL Image or Tensor: Gamma correction adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_gamma)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_gamma(img, gamma, gain)

    return F_t.adjust_gamma(img, gamma, gain)
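# Worked example (illustrative comment, not from the original file): with gain=1 and
# gamma=2.2, a mid-gray value of 128 maps to roughly 255 * (128 / 255) ** 2.2 ≈ 56.
#
#   t = torch.full((3, 8, 8), 128, dtype=torch.uint8)
#   out = adjust_gamma(t, gamma=2.2)   # values darken, since gamma > 1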


def _get_inverse_affine_matrix(
    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
    # Helper method to compute inverse matrix for affine transformation

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                      , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # RSS without scaling
    a = math.cos(rot - sy) / math.cos(sy)
    b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
    c = math.sin(rot - sy) / math.cos(sy)
    d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)

    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        matrix = [d, -b, 0.0, -c, a, 0.0]
        matrix = [x / scale for x in matrix]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
        matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
        # Apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx
        matrix[5] += cy
    else:
        matrix = [a, b, 0.0, c, d, 0.0]
        matrix = [x * scale for x in matrix]
        # Apply inverse of center translation: RSS * C^-1
        matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
        matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
        # Apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx + tx
        matrix[5] += cy + ty

    return matrix


def rotate(
    img: Tensor,
    angle: float,
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[int]] = None,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Rotate the image by angle.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to be rotated.
        angle (number): rotation angle value in degrees, counter-clockwise.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.
    Returns:
        PIL Image or Tensor: Rotated image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rotate)

    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    if not isinstance(img, torch.Tensor):
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)

    center_f = [0.0, 0.0]
    if center is not None:
        _, height, width = get_dimensions(img)
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

    # due to current incoherence of rotation angle direction between affine and rotate implementations
    # we need to set -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
    return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
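# Usage sketch (illustrative comment, not from the original file): rotating by 45
# degrees with expand=True grows the canvas so no content is cut off.
#
#   t = torch.rand(3, 100, 200)
#   out = rotate(t, angle=45.0, expand=True, fill=[0.0])
#   # out has a larger spatial size than the input so the rotated corners fit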


def affine(
    img: Tensor,
    angle: float,
    translate: List[int],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: Optional[List[float]] = None,
    center: Optional[List[int]] = None,
) -> Tensor:
    """Apply affine transformation on the image keeping image center invariant.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to transform.
        angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
        translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
        scale (float): overall scale
        shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
            If a sequence is specified, the first value corresponds to a shear parallel to the x-axis, while
            the second value corresponds to a shear parallel to the y-axis.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.
        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.

    Returns:
        PIL Image or Tensor: Transformed image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(affine)

    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, list, tuple)):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    _, height, width = get_dimensions(img)
    if not isinstance(img, torch.Tensor):
        # center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
        # it is visually better to estimate the center without 0.5 offset
        # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
        if center is None:
            center = [width * 0.5, height * 0.5]
        matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)

    center_f = [0.0, 0.0]
    if center is not None:
        _, height, width = get_dimensions(img)
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [1.0 * t for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
    return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
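

# Illustrative usage sketch (not part of the torchvision API): calling the
# `affine` functional above on a tensor image. The helper name and parameter
# values are arbitrary examples.
def _example_affine_usage() -> Tensor:
    img = torch.rand(3, 32, 32)
    # Rotate by 15 degrees, shift 2 pixels to the right, keep scale, no shear.
    return affine(
        img,
        angle=15.0,
        translate=[2, 0],
        scale=1.0,
        shear=[0.0, 0.0],
        interpolation=InterpolationMode.BILINEAR,
    )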


# to_grayscale() is a stand-alone functional that is never called from the
# transform classes; it appears to be kept only for backward compatibility.
@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
    """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
    This transform does not support torch Tensor.

    Args:
        img (PIL Image): PIL Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.

    Returns:
        PIL Image: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_grayscale)

    if isinstance(img, Image.Image):
        return F_pil.to_grayscale(img, num_output_channels)

    raise TypeError("Input should be PIL Image")


def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert RGB image to grayscale version of image.
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Note:
        Please note that this method supports only RGB images as input. For inputs in other color spaces,
        please consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.

    Args:
        img (PIL Image or Tensor): RGB Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.

    Returns:
        PIL Image or Tensor: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rgb_to_grayscale)
    if not isinstance(img, torch.Tensor):
        return F_pil.to_grayscale(img, num_output_channels)

    return F_t.rgb_to_grayscale(img, num_output_channels)
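

# Illustrative usage sketch (not part of the torchvision API): `rgb_to_grayscale`
# expects a [..., 3, H, W] input; the random image below is an arbitrary example.
def _example_rgb_to_grayscale_usage() -> Tensor:
    img = torch.rand(3, 16, 16)
    # Request a 3-channel result (r = g = b), e.g. to keep the input shape.
    return rgb_to_grayscale(img, num_output_channels=3)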


def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Erase the input Tensor Image with given value.
    This transform does not support PIL Image.

    Args:
        img (Tensor Image): Tensor image of size (C, H, W) to be erased.
        i (int): i in (i,j) i.e. coordinates of the upper left corner.
        j (int): j in (i,j) i.e. coordinates of the upper left corner.
        h (int): Height of the erased region.
        w (int): Width of the erased region.
        v: Erasing value.
        inplace (bool, optional): For in-place operations. By default, it is set to False.

    Returns:
        Tensor Image: Erased image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(erase)
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor Image. Got {type(img)}")

    return F_t.erase(img, i, j, h, w, v, inplace=inplace)
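

# Illustrative usage sketch (not part of the torchvision API): erasing a 4x4
# patch whose upper-left corner is at (i=2, j=2) with zeros. The helper name,
# shapes and values are arbitrary examples.
def _example_erase_usage() -> Tensor:
    img = torch.rand(3, 16, 16)
    return erase(img, i=2, j=2, h=4, w=4, v=torch.tensor(0.0))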


def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
    """Performs Gaussian blurring on the image by given kernel.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most one leading dimension.

    Args:
        img (PIL Image or Tensor): Image to be blurred
        kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
            like ``(kx, ky)`` or a single integer for square kernels.

            .. note::
                In torchscript mode kernel_size as single int is not supported, use a sequence of
                length 1: ``[ksize, ]``.
        sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
            sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
            same sigma in both X/Y directions. If None, then it is computed using
            ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
            Default is None.

            .. note::
                In torchscript mode sigma as single float is
                not supported, use a sequence of length 1: ``[sigma, ]``.

    Returns:
        PIL Image or Tensor: Gaussian Blurred version of the image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(gaussian_blur)
    if not isinstance(kernel_size, (int, list, tuple)):
        raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for ksize in kernel_size:
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

    if sigma is None:
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

    if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
    if isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
    for s in sigma:
        if s <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")

    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")

        t_img = pil_to_tensor(img)

    output = F_t.gaussian_blur(t_img, kernel_size, sigma)

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output, mode=img.mode)
    return output
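

# Illustrative usage sketch (not part of the torchvision API): blurring a tensor
# image with a 5x5 kernel. Kernel size and sigma are arbitrary example values;
# sigma may also be omitted and derived from the kernel size as documented above.
def _example_gaussian_blur_usage() -> Tensor:
    img = torch.rand(3, 32, 32)
    return gaussian_blur(img, kernel_size=[5, 5], sigma=[1.5, 1.5])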


def invert(img: Tensor) -> Tensor:
    """Invert the colors of an RGB/grayscale image.

    Args:
        img (PIL Image or Tensor): Image to have its colors inverted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: Color inverted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(invert)
    if not isinstance(img, torch.Tensor):
        return F_pil.invert(img)

    return F_t.invert(img)
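

# Illustrative usage sketch (not part of the torchvision API): inverting the
# colors of a random float tensor image. The helper name and shape are arbitrary.
def _example_invert_usage() -> Tensor:
    img = torch.rand(3, 16, 16)
    return invert(img)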


def posterize(img: Tensor, bits: int) -> Tensor:
    """Posterize an image by reducing the number of bits for each color channel.

    Args:
        img (PIL Image or Tensor): Image to have its colors posterized.
            If img is torch Tensor, it should be of type torch.uint8, and
            it is expected to be in [..., 1 or 3, H, W] format, where ... means
            it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".
        bits (int): The number of bits to keep for each channel (0-8).
    Returns:
        PIL Image or Tensor: Posterized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(posterize)
    if not (0 <= bits <= 8):
        raise ValueError(f"The number of bits should be between 0 and 8. Got {bits}")

    if not isinstance(img, torch.Tensor):
        return F_pil.posterize(img, bits)

    return F_t.posterize(img, bits)
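

# Illustrative usage sketch (not part of the torchvision API): posterize requires
# a ``torch.uint8`` tensor; keeping 4 bits per channel is an arbitrary example.
def _example_posterize_usage() -> Tensor:
    img = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
    return posterize(img, bits=4)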


def solarize(img: Tensor, threshold: float) -> Tensor:
    """Solarize an RGB/grayscale image by inverting all pixel values above a threshold.

    Args:
        img (PIL Image or Tensor): Image to have its colors inverted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".
        threshold (float): All pixels equal or above this value are inverted.
    Returns:
        PIL Image or Tensor: Solarized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(solarize)
    if not isinstance(img, torch.Tensor):
        return F_pil.solarize(img, threshold)

    return F_t.solarize(img, threshold)
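

# Illustrative usage sketch (not part of the torchvision API): solarizing a
# ``torch.uint8`` image so that every pixel equal to or above the threshold is
# inverted. The threshold of 128 is an arbitrary example value.
def _example_solarize_usage() -> Tensor:
    img = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
    return solarize(img, threshold=128)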


def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust the sharpness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        sharpness_factor (float):  How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.

    Returns:
        PIL Image or Tensor: Sharpness adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_sharpness)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_sharpness(img, sharpness_factor)

    return F_t.adjust_sharpness(img, sharpness_factor)
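

# Illustrative usage sketch (not part of the torchvision API): doubling the
# sharpness of a tensor image; the factor 2.0 is an arbitrary example value.
def _example_adjust_sharpness_usage() -> Tensor:
    img = torch.rand(3, 32, 32)
    return adjust_sharpness(img, sharpness_factor=2.0)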


def autocontrast(img: Tensor) -> Tensor:
    """Maximize contrast of an image by remapping its
    pixels per channel so that the lowest becomes black and the lightest
    becomes white.

    Args:
        img (PIL Image or Tensor): Image on which autocontrast is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was autocontrasted.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(autocontrast)
    if not isinstance(img, torch.Tensor):
        return F_pil.autocontrast(img)

    return F_t.autocontrast(img)
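

# Illustrative usage sketch (not part of the torchvision API): stretching the
# per-channel intensity range of a random tensor image so the darkest pixel maps
# to black and the lightest to white, as documented above.
def _example_autocontrast_usage() -> Tensor:
    img = torch.rand(3, 16, 16)
    return autocontrast(img)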


def equalize(img: Tensor) -> Tensor:
    """Equalize the histogram of an image by applying
    a non-linear mapping to the input in order to create a uniform
    distribution of grayscale values in the output.

    Args:
        img (PIL Image or Tensor): Image on which equalize is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            The tensor dtype must be ``torch.uint8`` and values are expected to be in ``[0, 255]``.
            If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was equalized.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(equalize)
    if not isinstance(img, torch.Tensor):
        return F_pil.equalize(img)

    return F_t.equalize(img)
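

# Illustrative usage sketch (not part of the torchvision API): equalize requires
# a ``torch.uint8`` tensor with values in [0, 255]; the random image is an
# arbitrary example.
def _example_equalize_usage() -> Tensor:
    img = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
    return equalize(img)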


def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Transform a tensor image with elastic transformations.
    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to grid_sample from the image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        img (PIL Image or Tensor): Image on which elastic_transform is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
        displacement (Tensor): The displacement field. Expected shape is [1, H, W, 2].
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(elastic_transform)
    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
        t_img = pil_to_tensor(img)

    shape = t_img.shape
    shape = (1,) + shape[-2:] + (2,)
    if shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {shape}, but given {displacement.shape}")

    # TODO: if image shape is [N1, N2, ..., C, H, W] and
    # displacement is [1, H, W, 2] we need to reshape input image
    # such grid_sampler takes internal code for 4D input

    output = F_t.elastic_transform(
        t_img,
        displacement,
        interpolation=interpolation.value,
        fill=fill,
    )

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output, mode=img.mode)
    return output
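

# Illustrative usage sketch (not part of the torchvision API): applying
# `elastic_transform` with a small random displacement field. The field must
# have shape [1, H, W, 2] matching the image; magnitudes here are arbitrary.
def _example_elastic_transform_usage() -> Tensor:
    img = torch.rand(3, 16, 16)
    displacement = 0.1 * torch.randn(1, 16, 16, 2)
    return elastic_transform(img, displacement, interpolation=InterpolationMode.BILINEAR)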