import math
import numbers
import warnings
from enum import Enum
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from torch import Tensor

try:
    import accimage
except ImportError:
    accimage = None

from ..utils import _log_api_usage_once
from . import functional_pil as F_pil, functional_tensor as F_t


class InterpolationMode(Enum):
    """Interpolation modes
    Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``.
    """

    NEAREST = "nearest"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    # For PIL compatibility
    BOX = "box"
    HAMMING = "hamming"
    LANCZOS = "lanczos"


# TODO: Once torchscript supports Enums with staticmethod
# this can be put into InterpolationMode as staticmethod
def _interpolation_modes_from_int(i: int) -> InterpolationMode:
    inverse_modes_mapping = {
        0: InterpolationMode.NEAREST,
        2: InterpolationMode.BILINEAR,
        3: InterpolationMode.BICUBIC,
        4: InterpolationMode.BOX,
        5: InterpolationMode.HAMMING,
        1: InterpolationMode.LANCZOS,
    }
    return inverse_modes_mapping[i]


pil_modes_mapping = {
    InterpolationMode.NEAREST: 0,
    InterpolationMode.BILINEAR: 2,
    InterpolationMode.BICUBIC: 3,
    InterpolationMode.BOX: 4,
    InterpolationMode.HAMMING: 5,
    InterpolationMode.LANCZOS: 1,
}

_is_pil_image = F_pil._is_pil_image


def get_dimensions(img: Tensor) -> List[int]:
    """Returns the dimensions of an image as [channels, height, width].

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image dimensions.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_dimensions)
    if isinstance(img, torch.Tensor):
        return F_t.get_dimensions(img)

    return F_pil.get_dimensions(img)


def get_image_size(img: Tensor) -> List[int]:
    """Returns the size of an image as [width, height].

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image size.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_image_size)
    if isinstance(img, torch.Tensor):
        return F_t.get_image_size(img)

    return F_pil.get_image_size(img)

def get_image_num_channels(img: Tensor) -> int:
    """Returns the number of channels of an image.

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        int: The number of channels.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_image_num_channels)
    if isinstance(img, torch.Tensor):
        return F_t.get_image_num_channels(img)

    return F_pil.get_image_num_channels(img)


@torch.jit.unused
def _is_numpy(img: Any) -> bool:
    return isinstance(img, np.ndarray)


@torch.jit.unused
def _is_numpy_image(img: Any) -> bool:
    return img.ndim in {2, 3}


def to_tensor(pic) -> Tensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.ToTensor` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_tensor)
    if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

    default_float_dtype = torch.get_default_dtype()

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]

        img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.to(dtype=default_float_dtype).div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic).to(dtype=default_float_dtype)

    # handle PIL Image
    mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}
    img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))

    if pic.mode == "1":
        img = 255 * img
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.to(dtype=default_float_dtype).div(255)
    else:
        return img
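
# Illustrative usage (a sketch, not part of the library; shapes and values below are
# arbitrary example inputs): ``to_tensor`` accepts an HWC ``numpy.ndarray`` or a PIL
# image and returns a CHW tensor; uint8 inputs are additionally scaled from [0, 255]
# to [0.0, 1.0].
#
#   >>> arr = np.random.randint(0, 256, size=(32, 48, 3), dtype=np.uint8)  # H, W, C
#   >>> t = to_tensor(arr)
#   >>> t.shape, t.dtype, float(t.max()) <= 1.0
#   (torch.Size([3, 32, 48]), torch.float32, True)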


def pil_to_tensor(pic):
    """Convert a ``PIL Image`` to a tensor of the same type.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.PILToTensor` for more details.

    .. note::

        A deep copy of the underlying array is performed.

    Args:
        pic (PIL Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pil_to_tensor)
    if not F_pil._is_pil_image(pic):
        raise TypeError(f"pic should be PIL Image. Got {type(pic)}")

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage format is always uint8 internally, so always return uint8 here
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
        pic.copyto(nppic)
        return torch.as_tensor(nppic)

    # handle PIL Image
    img = torch.as_tensor(np.array(pic, copy=True))
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1))
    return img
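
# Illustrative comparison (a sketch, not part of the library; the image below is an
# arbitrary example): unlike ``to_tensor``, ``pil_to_tensor`` keeps the original dtype
# and value range, so a uint8 RGB image comes back as a uint8 tensor in [0, 255].
#
#   >>> pil_img = Image.new("RGB", (48, 32), color=(255, 0, 0))  # width=48, height=32
#   >>> pil_to_tensor(pil_img).shape, pil_to_tensor(pil_img).dtype
#   (torch.Size([3, 32, 48]), torch.uint8)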


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly
    This function does not support PIL Image.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        Tensor: Converted image

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(convert_image_dtype)
    if not isinstance(image, torch.Tensor):
        raise TypeError("Input img should be Tensor Image")

    return F_t.convert_image_dtype(image, dtype)
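
# Illustrative usage (a sketch, not part of the library; the values are arbitrary
# example inputs): converting a uint8 tensor to float rescales the value range, so
# 255 maps to 1.0 and 0 stays 0.0.
#
#   >>> x = torch.tensor([[[0, 128, 255]]], dtype=torch.uint8)
#   >>> convert_image_dtype(x, torch.float32)
#   tensor([[[0.0000, 0.5020, 1.0000]]])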


def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_pil_image)
    if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
        raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")

    elif isinstance(pic, torch.Tensor):
        if pic.ndimension() not in {2, 3}:
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.")

        elif pic.ndimension() == 2:
            # if 2D image, add channel dimension (CHW)
            pic = pic.unsqueeze(0)

        # check number of channels
        if pic.shape[-3] > 4:
            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.")

    elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

        elif pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)

        # check number of channels
        if pic.shape[-1] > 4:
            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")

    npimg = pic
    if isinstance(pic, torch.Tensor):
        if pic.is_floating_point() and mode != "F":
            pic = pic.mul(255).byte()
        npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        raise TypeError(f"Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}")

    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = "L"
        elif npimg.dtype == np.int16:
            expected_mode = "I;16"
        elif npimg.dtype == np.int32:
            expected_mode = "I"
        elif npimg.dtype == np.float32:
            expected_mode = "F"
        if mode is not None and mode != expected_mode:
            raise ValueError(f"Incorrect mode ({mode}) supplied for input type {npimg.dtype}. Should be {expected_mode}")
        mode = expected_mode

    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ["LA"]
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "LA"

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "RGBA"
    else:
        permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs")
        if mode is None and npimg.dtype == np.uint8:
            mode = "RGB"

    if mode is None:
        raise TypeError(f"Input type {npimg.dtype} is not supported")

    return Image.fromarray(npimg, mode=mode)
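
# Illustrative round trip (a sketch, not part of the library; the tensor below is an
# arbitrary example input): a float CHW tensor in [0, 1] is scaled back to uint8 and
# converted to a PIL image whose mode is inferred from the number of channels.
#
#   >>> t = torch.rand(3, 32, 48)          # CHW, float in [0, 1]
#   >>> pil_img = to_pil_image(t)
#   >>> pil_img.mode, pil_img.size         # PIL reports size as (width, height)
#   ('RGB', (48, 32))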


def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Normalize a float tensor image with mean and standard deviation.
    This transform does not support PIL Image.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(normalize)
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")

    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
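
# Illustrative usage (a sketch, not part of the library): each channel is transformed
# as (x - mean[c]) / std[c]; the commonly used ImageNet statistics below are shown
# purely as example values.
#
#   >>> img = torch.rand(3, 224, 224)
#   >>> out = normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#   >>> torch.allclose(out[0], (img[0] - 0.485) / 0.229)
#   True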


def _compute_output_size(image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None) -> List[int]:
    if len(size) == 1:  # specified size only for the smallest edge
        h, w = image_size
        short, long = (w, h) if w <= h else (h, w)
        requested_new_short = size if isinstance(size, int) else size[0]

        new_short, new_long = requested_new_short, int(requested_new_short * long / short)

        if max_size is not None:
            if max_size <= requested_new_short:
                raise ValueError(
                    f"max_size = {max_size} must be strictly greater than the requested "
                    f"size for the smaller edge size = {size}"
                )
            if new_long > max_size:
                new_short, new_long = int(max_size * new_short / new_long), max_size

        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
    else:  # specified both h and w
        new_w, new_h = size[1], size[0]
    return [new_h, new_w]
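
# Worked example (a sketch, not part of the library; the sizes are arbitrary example
# inputs): with a 400x600 (HxW) image, ``size=[256]`` matches the smaller edge,
# giving 256x384; ``max_size=300`` then caps the longer edge and rescales the pair
# to 200x300.
#
#   >>> _compute_output_size((400, 600), size=[256], max_size=None)
#   [256, 384]
#   >>> _compute_output_size((400, 600), size=[256], max_size=300)
#   [200, 300]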


def resize(
    img: Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = None,
) -> Tensor:
    r"""Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
        types. See also below the ``antialias`` parameter, which can help make the output of PIL images and tensors
        closer.

    Args:
        img (PIL Image or Tensor): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e., if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e. the
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
            This can help make the output for PIL images and tensors closer.

    Returns:
        PIL Image or Tensor: Resized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resize)
    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
            "Please use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be an InterpolationMode")

    if isinstance(size, (list, tuple)):
        if len(size) not in [1, 2]:
            raise ValueError(
                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
            )
        if max_size is not None and len(size) != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )

    _, image_height, image_width = get_dimensions(img)
    if isinstance(size, int):
        size = [size]
    output_size = _compute_output_size((image_height, image_width), size, max_size)

    if (image_height, image_width) == output_size:
        return img

    if not isinstance(img, torch.Tensor):
        if antialias is not None and not antialias:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)

    return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)
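
# Illustrative usage (a sketch, not part of the library; the shape is an arbitrary
# example input): passing a length-1 size resizes the smaller edge and keeps the
# aspect ratio, here on a tensor input with antialiasing enabled.
#
#   >>> img = torch.rand(3, 400, 600)
#   >>> resize(img, [256], interpolation=InterpolationMode.BILINEAR, antialias=True).shape
#   torch.Size([3, 256, 384])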


def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Args:
        img (PIL Image or Tensor): Image to be padded.
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        PIL Image or Tensor: Padded image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pad)
    if not isinstance(img, torch.Tensor):
        return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)

    return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
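
# Illustrative usage (a sketch, not part of the library; the shape is an arbitrary
# example input): a length-4 padding sequence is interpreted as (left, top, right,
# bottom), so a 32x48 image padded with [1, 2, 3, 4] becomes
# (2 + 32 + 4) x (1 + 48 + 3) = 38 x 52.
#
#   >>> img = torch.zeros(3, 32, 48)
#   >>> pad(img, [1, 2, 3, 4], fill=0, padding_mode="constant").shape
#   torch.Size([3, 38, 52])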


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop the given image at specified location and output size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(crop)
    if not isinstance(img, torch.Tensor):
        return F_pil.crop(img, top, left, height, width)

    return F_t.crop(img, top, left, height, width)

def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
            it is used for both directions.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(center_crop)
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    _, image_height, image_width = get_dimensions(img)
    crop_height, crop_width = output_size

    if crop_width > image_width or crop_height > image_height:
        padding_ltrb = [
            (crop_width - image_width) // 2 if crop_width > image_width else 0,
            (crop_height - image_height) // 2 if crop_height > image_height else 0,
            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
        ]
        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
        _, image_height, image_width = get_dimensions(img)
        if crop_width == image_width and crop_height == image_height:
            return img

    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop(img, crop_top, crop_left, crop_height, crop_width)
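
# Illustrative usage (a sketch, not part of the library; shapes are arbitrary example
# inputs): a common evaluation pipeline resizes the smaller edge and then takes a
# square center crop.
#
#   >>> img = torch.rand(3, 400, 600)
#   >>> center_crop(resize(img, [256], antialias=True), [224]).shape
#   torch.Size([3, 224, 224])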


def resized_crop(
    img: Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = None,
) -> Tensor:
    """Crop the given image and resize it to desired size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
            This can help make the output for PIL images and tensors closer.
    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resized_crop)
    img = crop(img, top, left, height, width)
    img = resize(img, size, interpolation, antialias=antialias)
    return img


def hflip(img: Tensor) -> Tensor:
    """Horizontally flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor:  Horizontally flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(hflip)
    if not isinstance(img, torch.Tensor):
        return F_pil.hflip(img)

    return F_t.hflip(img)


def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
    """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.

    In Perspective Transform each pixel (x, y) in the original image gets transformed as,
     (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )

    Args:
        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.

    Returns:
        octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
    """
    a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)

    for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
        a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
        a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

    b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
    res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution

    output: List[float] = res.tolist()
    return output
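
# Illustrative check (a sketch, not part of the library; the corner list is an
# arbitrary example): the corners are given as ``[top-left, top-right, bottom-right,
# bottom-left]``; mapping the corners onto themselves is the identity transform, so
# the coefficients reduce to (a, b, c, d, e, f, g, h) = (1, 0, 0, 0, 1, 0, 0, 0).
#
#   >>> corners = [[0, 0], [31, 0], [31, 31], [0, 31]]
#   >>> coeffs = _get_perspective_coeffs(corners, corners)
#   >>> torch.allclose(torch.tensor(coeffs), torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]), atol=1e-4)
#   True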


def perspective(
    img: Tensor,
    startpoints: List[List[int]],
    endpoints: List[List[int]],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Perform perspective transform of the given image.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be transformed.
        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.

    Returns:
        PIL Image or Tensor: transformed Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(perspective)

    coeffs = _get_perspective_coeffs(startpoints, endpoints)

    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
            "Please use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be an InterpolationMode")

    if not isinstance(img, torch.Tensor):
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.perspective(img, coeffs, interpolation=pil_interpolation, fill=fill)

    return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill)


def vflip(img: Tensor) -> Tensor:
    """Vertically flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor:  Vertically flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(vflip)
    if not isinstance(img, torch.Tensor):
        return F_pil.vflip(img)

    return F_t.vflip(img)


def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Crop the given image into four corners and the central crop.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Returns:
       tuple: tuple (tl, tr, bl, br, center)
       Corresponding top left, top right, bottom left, bottom right and center crop.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(five_crop)
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    _, image_height, image_width = get_dimensions(img)
    crop_height, crop_width = size
    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    tl = crop(img, 0, 0, crop_height, crop_width)
    tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
    br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)

    center = center_crop(img, [crop_height, crop_width])

    return tl, tr, bl, br, center


def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]:
    """Generate ten cropped images from the given image.
    Crop the given image into four corners and the central crop plus the
    flipped version of these (horizontal flipping is used by default).
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
        Corresponding top left, top right, bottom left, bottom right and
        center crop and same for the flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(ten_crop)
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)
    return first_five + second_five
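
# Illustrative usage (a sketch, not part of the library; the shapes are arbitrary
# example inputs): the ten crops are commonly stacked into a batch, evaluated, and
# the per-crop predictions averaged.
#
#   >>> img = torch.rand(3, 256, 256)
#   >>> crops = ten_crop(img, [224, 224])
#   >>> len(crops), torch.stack(crops).shape
#   (10, torch.Size([10, 3, 224, 224]))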


def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Adjust brightness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        brightness_factor (float):  How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image or Tensor: Brightness adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_brightness)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_brightness(img, brightness_factor)

    return F_t.adjust_brightness(img, brightness_factor)


def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust contrast of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image or Tensor: Contrast adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_contrast)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_contrast(img, contrast_factor)

    return F_t.adjust_contrast(img, contrast_factor)


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Adjust color saturation of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        saturation_factor (float):  How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image or Tensor: Saturation adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_saturation)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_saturation(img, saturation_factor)

    return F_t.adjust_saturation(img, saturation_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
            Note: the pixel values of the input image have to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
        hue_factor (float):  How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image or Tensor: Hue adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_hue)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_hue(img, hue_factor)

    return F_t.adjust_hue(img, hue_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (PIL Image or Tensor): PIL Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, modes with transparency (alpha channel) are not supported.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 makes the shadows darker,
            while gamma smaller than 1 makes dark regions lighter.
        gain (float): The constant multiplier.
    Returns:
        PIL Image or Tensor: Gamma correction adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_gamma)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_gamma(img, gamma, gain)

    return F_t.adjust_gamma(img, gamma, gain)
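
# Worked example (a sketch, not part of the library; the values are arbitrary example
# inputs): with gain = 1 and gamma = 2.0, a mid-gray input of 0.5 maps to
# 0.5 ** 2.0 = 0.25, i.e. shadows get darker, while gamma = 0.5 would map it to
# roughly 0.71 and lighten dark regions.
#
#   >>> x = torch.full((3, 1, 1), 0.5)
#   >>> float(adjust_gamma(x, gamma=2.0)[0])
#   0.25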


def _get_inverse_affine_matrix(
    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
    # Helper method to compute inverse matrix for affine transformation

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                      , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # RSS without scaling
    a = math.cos(rot - sy) / math.cos(sy)
    b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
    c = math.sin(rot - sy) / math.cos(sy)
    d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        matrix = [d, -b, 0.0, -c, a, 0.0]
        matrix = [x / scale for x in matrix]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
        matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
        # Apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx
        matrix[5] += cy
    else:
        matrix = [a, b, 0.0, c, d, 0.0]
        matrix = [x * scale for x in matrix]
        # Apply inverse of center translation: RSS * C^-1
        matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
        matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
        # Apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx + tx
        matrix[5] += cy + ty

    return matrix
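
# Worked check (a sketch, not part of the library; the inputs are arbitrary example
# values): with no rotation, scaling or shear the linear part of the returned matrix
# is the identity, and a pure translation is inverted, i.e. (tx, ty) = (10, 5)
# appears negated in the translation entries of the inverse.
#
#   >>> m = _get_inverse_affine_matrix([0.0, 0.0], 0.0, [10.0, 5.0], 1.0, [0.0, 0.0])
#   >>> (m[2], m[5])          # the translation is inverted
#   (-10.0, -5.0)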
def rotate(
    img: Tensor,
    angle: float,
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[int]] = None,
    fill: Optional[List[float]] = None,
    resample: Optional[int] = None,
) -> Tensor:
    """Rotate the image by angle.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to be rotated.
        angle (number): rotation angle value in degrees, counter-clockwise.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.
        resample (int, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``interpolation``
                instead.

    Returns:
        PIL Image or Tensor: Rotated image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rotate)
    if resample is not None:
        warnings.warn(
            "The parameter 'resample' is deprecated since 0.12 and will be removed in 0.14. "
            "Please use 'interpolation' instead."
        )
        interpolation = _interpolation_modes_from_int(resample)

    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
            "Please use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be an InterpolationMode")

    if not isinstance(img, torch.Tensor):
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)

    center_f = [0.0, 0.0]
    if center is not None:
        _, height, width = get_dimensions(img)
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

    # due to current incoherence of rotation angle direction between affine and rotate implementations
    # we need to set -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
    return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
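
# Illustrative usage (a sketch, not part of the library; the shape is an arbitrary
# example input): by default the output canvas keeps the input size, while
# ``expand=True`` grows it so the whole rotated image fits.
#
#   >>> img = torch.ones(3, 100, 100)
#   >>> rotate(img, 45.0).shape                           # same size as the input
#   torch.Size([3, 100, 100])
#   >>> rotate(img, 45.0, expand=True).shape[-1] > 100    # expand grows the canvas
#   True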


vfdev's avatar
vfdev committed
1130
def affine(
1131
1132
1133
1134
1135
1136
1137
1138
1139
    img: Tensor,
    angle: float,
    translate: List[int],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: Optional[List[float]] = None,
    resample: Optional[int] = None,
    fillcolor: Optional[List[float]] = None,
1140
    center: Optional[List[int]] = None,
vfdev's avatar
vfdev committed
1141
1142
) -> Tensor:
    """Apply affine transformation on the image keeping image center invariant.
1143
    If the image is torch Tensor, it is expected
vfdev's avatar
vfdev committed
1144
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
1145
1146

    Args:
vfdev's avatar
vfdev committed
1147
        img (PIL Image or Tensor): image to transform.
1148
1149
        angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
        translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
1150
        scale (float): overall scale
1151
1152
        shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
            If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while
vfdev's avatar
vfdev committed
1153
            the second value corresponds to a shear parallel to the y axis.
1154
1155
1156
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.
        fillcolor (sequence or number, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``fill`` instead.
        resample (int, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``interpolation``
                instead.
        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.

    Returns:
        PIL Image or Tensor: Transformed image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(affine)
    if resample is not None:
        warnings.warn(
            "The parameter 'resample' is deprecated since 0.12 and will be removed in 0.14. "
            "Please use 'interpolation' instead."
        )
        interpolation = _interpolation_modes_from_int(resample)

    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
            "Please use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if fillcolor is not None:
        warnings.warn(
            "The parameter 'fillcolor' is deprecated since 0.12 and will be removed in 0.14. "
            "Please use 'fill' instead."
        )
        fill = fillcolor

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be an InterpolationMode")

    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    _, height, width = get_dimensions(img)
    if not isinstance(img, torch.Tensor):
        # center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
        # it is visually better to estimate the center without 0.5 offset
        # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
        if center is None:
            center = [width * 0.5, height * 0.5]
        matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)

    center_f = [0.0, 0.0]
    if center is not None:
        _, height, width = get_dimensions(img)
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [1.0 * t for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
    return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)


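# Illustrative usage of ``affine`` (an editorial sketch; the parameter values are arbitrary):
# rotate by 10 degrees, translate by (5, -5) pixels, scale to 80% and shear 2 degrees along
# the x axis, all keeping the image center fixed.
#
#     img = torch.zeros(3, 64, 64, dtype=torch.uint8)
#     out = affine(img, angle=10.0, translate=[5, -5], scale=0.8, shear=[2.0, 0.0])

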
@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
    """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
    This transform does not support torch Tensor.

    Args:
        img (PIL Image): PIL Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.

    Returns:
        PIL Image: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_grayscale)
    if isinstance(img, Image.Image):
        return F_pil.to_grayscale(img, num_output_channels)

    raise TypeError("Input should be PIL Image")


def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert RGB image to grayscale version of image.
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions

    Note:
        This method supports only RGB images as input. For inputs in other color spaces,
        please consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.

    Args:
        img (PIL Image or Tensor): RGB Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.

    Returns:
        PIL Image or Tensor: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rgb_to_grayscale)
    if not isinstance(img, torch.Tensor):
        return F_pil.to_grayscale(img, num_output_channels)

    return F_t.rgb_to_grayscale(img, num_output_channels)
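
# Illustrative usage of ``rgb_to_grayscale`` (an editorial sketch; the shape is arbitrary):
# a 3-channel tensor is reduced to one luminance channel, or replicated to three identical
# channels when ``num_output_channels=3``.
#
#     img = torch.rand(3, 16, 16)
#     gray = rgb_to_grayscale(img)                          # shape (1, 16, 16)
#     gray3 = rgb_to_grayscale(img, num_output_channels=3)  # shape (3, 16, 16)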


def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Erase the input Tensor Image with given value.
    This transform does not support PIL Image.

    Args:
        img (Tensor Image): Tensor image of size (C, H, W) to be erased
        i (int): Row index (from the top) of the upper left corner of the erased region.
        j (int): Column index (from the left) of the upper left corner of the erased region.
        h (int): Height of the erased region.
        w (int): Width of the erased region.
        v: Erasing value.
        inplace(bool, optional): Whether to apply the erase in place. Default is False.

    Returns:
        Tensor Image: Erased image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(erase)
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor Image. Got {type(img)}")

    return F_t.erase(img, i, j, h, w, v, inplace=inplace)
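
# Illustrative usage of ``erase`` (an editorial sketch; coordinates are arbitrary): zero out
# a 10x20 region whose upper left corner is at row 5, column 8.
#
#     img = torch.rand(3, 32, 32)
#     out = erase(img, i=5, j=8, h=10, w=20, v=torch.tensor(0.0))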


def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
    """Performs Gaussian blurring on the image by given kernel.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be blurred
        kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
            like ``(kx, ky)`` or a single integer for square kernels.

            .. note::
                In torchscript mode kernel_size as single int is not supported, use a sequence of
                length 1: ``[ksize, ]``.
        sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
            sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
            same sigma in both X/Y directions. If None, then it is computed using
            ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
            Default is None.

            .. note::
                In torchscript mode sigma as single float is
                not supported, use a sequence of length 1: ``[sigma, ]``.

    Returns:
        PIL Image or Tensor: Gaussian Blurred version of the image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(gaussian_blur)
    if not isinstance(kernel_size, (int, list, tuple)):
        raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for ksize in kernel_size:
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

    if sigma is None:
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

    if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
    if isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
    for s in sigma:
        if s <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")

    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")

        t_img = pil_to_tensor(img)

    output = F_t.gaussian_blur(t_img, kernel_size, sigma)

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output, mode=img.mode)
    return output
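
# Illustrative usage of ``gaussian_blur`` (an editorial sketch; sizes are arbitrary). With
# ``sigma=None`` and a 5x5 kernel, sigma is derived from the kernel size as documented above:
# 0.3 * ((5 - 1) * 0.5 - 1) + 0.8 = 1.1.
#
#     img = torch.rand(3, 64, 64)
#     out = gaussian_blur(img, kernel_size=[5, 5])                    # sigma computed as 1.1
#     out = gaussian_blur(img, kernel_size=[5, 5], sigma=[2.0, 2.0])  # explicit sigma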


def invert(img: Tensor) -> Tensor:
    """Invert the colors of an RGB/grayscale image.

    Args:
        img (PIL Image or Tensor): Image to have its colors inverted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: Color inverted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(invert)
    if not isinstance(img, torch.Tensor):
        return F_pil.invert(img)

    return F_t.invert(img)
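
# Illustrative usage of ``invert`` (an editorial sketch): for a uint8 tensor the result is
# 255 - img, for a float tensor in [0, 1] it is 1.0 - img.
#
#     img = torch.tensor([[[0, 128, 255]]], dtype=torch.uint8)  # shape (1, 1, 3)
#     out = invert(img)                                         # values 255, 127, 0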


def posterize(img: Tensor, bits: int) -> Tensor:
    """Posterize an image by reducing the number of bits for each color channel.

    Args:
        img (PIL Image or Tensor): Image to have its colors posterized.
            If img is torch Tensor, it should be of type torch.uint8 and
            it is expected to be in [..., 1 or 3, H, W] format, where ... means
            it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".
        bits (int): The number of bits to keep for each channel (0-8).
    Returns:
        PIL Image or Tensor: Posterized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(posterize)
    if not (0 <= bits <= 8):
        raise ValueError(f"The number of bits should be between 0 and 8. Got {bits}")

    if not isinstance(img, torch.Tensor):
        return F_pil.posterize(img, bits)

    return F_t.posterize(img, bits)
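
# Illustrative usage of ``posterize`` (an editorial sketch): keeping only the top 2 bits of
# each uint8 value quantizes it to one of four levels (0, 64, 128, 192).
#
#     img = torch.tensor([[[3, 70, 130, 250]]], dtype=torch.uint8)
#     out = posterize(img, bits=2)  # values 0, 64, 128, 192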


def solarize(img: Tensor, threshold: float) -> Tensor:
    """Solarize an RGB/grayscale image by inverting all pixel values above a threshold.

    Args:
        img (PIL Image or Tensor): Image to have its colors inverted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".
        threshold (float): All pixels equal or above this value are inverted.
    Returns:
        PIL Image or Tensor: Solarized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(solarize)
    if not isinstance(img, torch.Tensor):
        return F_pil.solarize(img, threshold)

    return F_t.solarize(img, threshold)
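
# Illustrative usage of ``solarize`` (an editorial sketch): pixels at or above the threshold
# are inverted, pixels below it are left unchanged.
#
#     img = torch.tensor([[[10, 100, 200]]], dtype=torch.uint8)
#     out = solarize(img, threshold=128)  # values 10, 100, 55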


def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust the sharpness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        sharpness_factor (float):  How much to adjust the sharpness. Can be
            any non negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.

    Returns:
        PIL Image or Tensor: Sharpness adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_sharpness)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_sharpness(img, sharpness_factor)

    return F_t.adjust_sharpness(img, sharpness_factor)
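
# Illustrative usage of ``adjust_sharpness`` (an editorial sketch; the factors are arbitrary):
# 0 blurs the image, 1 returns it unchanged, values above 1 sharpen it.
#
#     img = torch.rand(3, 32, 32)
#     blurred = adjust_sharpness(img, sharpness_factor=0.0)
#     sharper = adjust_sharpness(img, sharpness_factor=2.0)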


def autocontrast(img: Tensor) -> Tensor:
    """Maximize contrast of an image by remapping its
    pixels per channel so that the lowest becomes black and the lightest
    becomes white.

    Args:
        img (PIL Image or Tensor): Image on which autocontrast is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was autocontrasted.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(autocontrast)
    if not isinstance(img, torch.Tensor):
        return F_pil.autocontrast(img)

    return F_t.autocontrast(img)
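
# Illustrative usage of ``autocontrast`` (an editorial sketch): per channel, the darkest pixel
# is stretched to 0 and the brightest to 255 for a uint8 input.
#
#     img = torch.randint(60, 200, (3, 32, 32), dtype=torch.uint8)
#     out = autocontrast(img)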


def equalize(img: Tensor) -> Tensor:
    """Equalize the histogram of an image by applying
    a non-linear mapping to the input in order to create a uniform
    distribution of grayscale values in the output.

    Args:
        img (PIL Image or Tensor): Image on which equalize is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            The tensor dtype must be ``torch.uint8`` and values are expected to be in ``[0, 255]``.
            If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was equalized.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(equalize)
    if not isinstance(img, torch.Tensor):
        return F_pil.equalize(img)

    return F_t.equalize(img)
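
# Illustrative usage of ``equalize`` (an editorial sketch): the tensor input must be uint8;
# each channel's histogram is remapped towards a uniform distribution.
#
#     img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
#     out = equalize(img)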


def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Transform a tensor image with elastic transformations.
    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to grid_sample from the image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        img (PIL Image or Tensor): Image on which elastic_transform is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
        displacement (Tensor): The displacement field. Expected shape is [1, H, W, 2].
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or str or tuple value is supported for PIL Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(elastic_transform)
    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
        t_img = pil_to_tensor(img)

    shape = t_img.shape
    shape = (1,) + shape[-2:] + (2,)
    if shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {shape}, but given {displacement.shape}")

    # TODO: if image shape is [N1, N2, ..., C, H, W] and
    # displacement is [1, H, W, 2] we need to reshape the input image
    # so that grid_sampler takes the internal code path for 4D input

    output = F_t.elastic_transform(
        t_img,
        displacement,
        interpolation=interpolation.value,
        fill=fill,
    )

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output, mode=img.mode)
    return output
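

# Illustrative usage of ``elastic_transform`` (an editorial sketch): the displacement below is
# plain random noise purely for demonstration; in practice it is usually a smoothed random
# field controlled by alpha and sigma, as described in the docstring above.
#
#     img = torch.rand(3, 32, 32)
#     displacement = torch.randn(1, 32, 32, 2) * 0.05
#     out = elastic_transform(img, displacement, interpolation=InterpolationMode.BILINEAR)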