import math
import numbers
import warnings
from enum import Enum
from typing import List, Tuple, Any, Optional

import numpy as np
import torch
from PIL import Image
from torch import Tensor

try:
    import accimage
except ImportError:
    accimage = None

from ..utils import _log_api_usage_once
from . import functional_pil as F_pil
from . import functional_tensor as F_t


class InterpolationMode(Enum):
    """Interpolation modes
    Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``.
    """

    NEAREST = "nearest"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    # For PIL compatibility
    BOX = "box"
    HAMMING = "hamming"
    LANCZOS = "lanczos"


# TODO: Once torchscript supports Enums with staticmethod
# this can be put into InterpolationMode as staticmethod
def _interpolation_modes_from_int(i: int) -> InterpolationMode:
    inverse_modes_mapping = {
        0: InterpolationMode.NEAREST,
        2: InterpolationMode.BILINEAR,
        3: InterpolationMode.BICUBIC,
        4: InterpolationMode.BOX,
        5: InterpolationMode.HAMMING,
        1: InterpolationMode.LANCZOS,
    }
    return inverse_modes_mapping[i]


pil_modes_mapping = {
    InterpolationMode.NEAREST: 0,
    InterpolationMode.BILINEAR: 2,
    InterpolationMode.BICUBIC: 3,
    InterpolationMode.BOX: 4,
    InterpolationMode.HAMMING: 5,
    InterpolationMode.LANCZOS: 1,
}

_is_pil_image = F_pil._is_pil_image


def get_image_size(img: Tensor) -> List[int]:
    """Returns the size of an image as [width, height].

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image size.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_image_size)
    if isinstance(img, torch.Tensor):
        return F_t.get_image_size(img)

    return F_pil.get_image_size(img)


def get_image_num_channels(img: Tensor) -> int:
    """Returns the number of channels of an image.

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        int: The number of channels.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_image_num_channels)
    if isinstance(img, torch.Tensor):
        return F_t.get_image_num_channels(img)

    return F_pil.get_image_num_channels(img)


@torch.jit.unused
def _is_numpy(img: Any) -> bool:
    return isinstance(img, np.ndarray)


@torch.jit.unused
def _is_numpy_image(img: Any) -> bool:
    return img.ndim in {2, 3}


def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.ToTensor` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_tensor)
    if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

    default_float_dtype = torch.get_default_dtype()

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]

        img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.to(dtype=default_float_dtype).div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic).to(dtype=default_float_dtype)

    # handle PIL Image
    mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}
    img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))

    if pic.mode == "1":
        img = 255 * img
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.to(dtype=default_float_dtype).div(255)
    else:
        return img
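# Illustrative use of ``to_tensor``, added for orientation and kept as a comment so it is not part
# of the module's behaviour (a minimal sketch; shapes and scaling follow the docstring above):
#
#     pil_img = Image.new("RGB", (8, 8))
#     t = to_tensor(pil_img)                      # CHW float tensor of shape (3, 8, 8), values in [0, 1]
#     arr = np.zeros((8, 8, 3), dtype=np.uint8)
#     t = to_tensor(arr)                          # HWC uint8 ndarray is converted and divided by 255 the same way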


def pil_to_tensor(pic):
    """Convert a ``PIL Image`` to a tensor of the same type.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.PILToTensor` for more details.

    .. note::

        A deep copy of the underlying array is performed.

    Args:
        pic (PIL Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pil_to_tensor)
    if not F_pil._is_pil_image(pic):
        raise TypeError(f"pic should be PIL Image. Got {type(pic)}")

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage format is always uint8 internally, so always return uint8 here
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
        pic.copyto(nppic)
        return torch.as_tensor(nppic)

    # handle PIL Image
    img = torch.as_tensor(np.array(pic, copy=True))
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1))
    return img


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly
    This function does not support PIL Image.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        Tensor: Converted image

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(convert_image_dtype)
    if not isinstance(image, torch.Tensor):
        raise TypeError("Input img should be Tensor Image")

    return F_t.convert_image_dtype(image, dtype)
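# Illustrative use of ``convert_image_dtype``, kept as a comment (a sketch; the point is that
# values are rescaled to the target dtype's range, not merely cast):
#
#     img_uint8 = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
#     img_float = convert_image_dtype(img_uint8, torch.float32)   # values now in [0.0, 1.0]
#     img_back = convert_image_dtype(img_float, torch.uint8)      # values back in [0, 255]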


def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_pil_image)
    if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
        raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")

    elif isinstance(pic, torch.Tensor):
        if pic.ndimension() not in {2, 3}:
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.")

        elif pic.ndimension() == 2:
            # if 2D image, add channel dimension (CHW)
            pic = pic.unsqueeze(0)

        # check number of channels
        if pic.shape[-3] > 4:
            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.")

    elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

        elif pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)

        # check number of channels
        if pic.shape[-1] > 4:
            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")

    npimg = pic
    if isinstance(pic, torch.Tensor):
        if pic.is_floating_point() and mode != "F":
            pic = pic.mul(255).byte()
        npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        raise TypeError(f"Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}")

    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = "L"
        elif npimg.dtype == np.int16:
            expected_mode = "I;16"
        elif npimg.dtype == np.int32:
            expected_mode = "I"
        elif npimg.dtype == np.float32:
            expected_mode = "F"
        if mode is not None and mode != expected_mode:
            raise ValueError(f"Incorrect mode ({mode}) supplied for input type {npimg.dtype}. Should be {expected_mode}")
        mode = expected_mode

    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ["LA"]
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "LA"

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "RGBA"
    else:
        permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs")
        if mode is None and npimg.dtype == np.uint8:
            mode = "RGB"

    if mode is None:
        raise TypeError(f"Input type {npimg.dtype} is not supported")

    return Image.fromarray(npimg, mode=mode)


def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Normalize a float tensor image with mean and standard deviation.
    This transform does not support PIL Image.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(normalize)
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Input tensor should be a torch tensor. Got {type(tensor)}.")

    if not tensor.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")

    if tensor.ndim < 3:
        raise ValueError(
            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
        )

    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std == 0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    tensor.sub_(mean).div_(std)
    return tensor
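# Illustrative use of ``normalize``, kept as a comment (a sketch; the mean/std below are the widely
# used ImageNet statistics chosen for illustration, not values required by this function):
#
#     img = torch.rand(3, 224, 224)
#     out = normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#     # each channel c becomes (img[c] - mean[c]) / std[c]; pass inplace=True to skip the clone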


def resize(
    img: Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = None,
) -> Tensor:
    r"""Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
        types. See also below the ``antialias`` parameter, which can help make the output of PIL images and tensors
        closer.

    Args:
        img (PIL Image or Tensor): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e., if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e. the
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
            ``InterpolationMode.BILINEAR`` only mode. This can help make the output for PIL images and tensors
            closer.

            .. warning::
                There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.

    Returns:
        PIL Image or Tensor: Resized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resize)
    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be an InterpolationMode")

    if not isinstance(img, torch.Tensor):
        if antialias is not None and not antialias:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)

    return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)
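# Illustrative use of ``resize``, kept as a comment (a sketch of the two ``size`` forms described
# in the docstring above; the input shape is arbitrary):
#
#     img = torch.rand(3, 240, 320)
#     out = resize(img, [224, 224])            # exact (h, w) output
#     out = resize(img, [256])                 # smaller edge becomes 256, aspect ratio kept
#     out = resize(img, [256], max_size=512)   # additionally cap the longer edge at 512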


def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant.

    Args:
        img (PIL Image or Tensor): Image to be padded.
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or str or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If the input is a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        PIL Image or Tensor: Padded image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pad)
    if not isinstance(img, torch.Tensor):
        return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)

    return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
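# Illustrative use of ``pad``, kept as a comment (a sketch of the padding argument forms listed in
# the docstring above; output shapes assume a (3, 32, 32) input):
#
#     img = torch.rand(3, 32, 32)
#     out = pad(img, [4])                                   # all four borders -> (3, 40, 40)
#     out = pad(img, [2, 4])                                # left/right=2, top/bottom=4 -> (3, 40, 36)
#     out = pad(img, [1, 2, 3, 4], padding_mode="reflect")  # left, top, right, bottom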


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop the given image at specified location and output size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        PIL Image or Tensor: Cropped image.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(crop)

    if not isinstance(img, torch.Tensor):
        return F_pil.crop(img, top, left, height, width)

    return F_t.crop(img, top, left, height, width)


def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
            it is used for both directions.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(center_crop)
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    image_width, image_height = get_image_size(img)
    crop_height, crop_width = output_size

    if crop_width > image_width or crop_height > image_height:
        padding_ltrb = [
            (crop_width - image_width) // 2 if crop_width > image_width else 0,
            (crop_height - image_height) // 2 if crop_height > image_height else 0,
            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
        ]
        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
        image_width, image_height = get_image_size(img)
        if crop_width == image_width and crop_height == image_height:
            return img

    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop(img, crop_top, crop_left, crop_height, crop_width)


def resized_crop(
    img: Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> Tensor:
    """Crop the given image and resize it to desired size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resized_crop)
    img = crop(img, top, left, height, width)
    img = resize(img, size, interpolation)
    return img


def hflip(img: Tensor) -> Tensor:
    """Horizontally flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor:  Horizontally flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(hflip)
    if not isinstance(img, torch.Tensor):
        return F_pil.hflip(img)

    return F_t.hflip(img)


def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
    """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.

    In Perspective Transform each pixel (x, y) in the original image gets transformed as,
     (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )

    Args:
        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.

    Returns:
        octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
    """
    a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)

    for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
        a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
        a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

    b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
    res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution

    output: List[float] = res.tolist()
    return output


def perspective(
    img: Tensor,
    startpoints: List[List[int]],
    endpoints: List[List[int]],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Perform perspective transform of the given image.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be transformed.
        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.

    Returns:
        PIL Image or Tensor: transformed Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(perspective)

    coeffs = _get_perspective_coeffs(startpoints, endpoints)

    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be an InterpolationMode")

    if not isinstance(img, torch.Tensor):
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.perspective(img, coeffs, interpolation=pil_interpolation, fill=fill)

    return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill)
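# Illustrative use of ``perspective``, kept as a comment (a sketch; the corner points below are
# arbitrary illustration values, listed as [top-left, top-right, bottom-right, bottom-left]):
#
#     img = torch.rand(3, 100, 100)
#     startpoints = [[0, 0], [99, 0], [99, 99], [0, 99]]
#     endpoints = [[10, 5], [95, 0], [99, 90], [0, 99]]
#     out = perspective(img, startpoints, endpoints)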


def vflip(img: Tensor) -> Tensor:
    """Vertically flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor:  Vertically flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(vflip)
    if not isinstance(img, torch.Tensor):
        return F_pil.vflip(img)

    return F_t.vflip(img)


def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Crop the given image into four corners and the central crop.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Returns:
       tuple: tuple (tl, tr, bl, br, center)
       Corresponding top left, top right, bottom left, bottom right and center crop.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(five_crop)
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    image_width, image_height = get_image_size(img)
    crop_height, crop_width = size
    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    tl = crop(img, 0, 0, crop_height, crop_width)
    tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
    br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)

    center = center_crop(img, [crop_height, crop_width])

    return tl, tr, bl, br, center


def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]:
    """Generate ten cropped images from the given image.
    Crop the given image into four corners and the central crop plus the
    flipped version of these (horizontal flipping is used by default).
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
        Corresponding top left, top right, bottom left, bottom right and
        center crop and same for the flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(ten_crop)
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)
    return first_five + second_five
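# Illustrative use of ``ten_crop``, kept as a comment (a sketch; the result is a tuple of ten crops,
# the first five from the original image and the last five from the flipped image):
#
#     img = torch.rand(3, 256, 256)
#     crops = ten_crop(img, [224, 224])
#     tl, tr, bl, br, center = crops[:5]              # crops of the original image
#     tl_f, tr_f, bl_f, br_f, center_f = crops[5:]    # crops of the horizontally flipped image
#                                                     # (vertically flipped if vertical_flip=True)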


def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Adjust brightness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        brightness_factor (float):  How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image or Tensor: Brightness adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_brightness)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_brightness(img, brightness_factor)

    return F_t.adjust_brightness(img, brightness_factor)


def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust contrast of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image or Tensor: Contrast adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_contrast)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_contrast(img, contrast_factor)

    return F_t.adjust_contrast(img, contrast_factor)


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Adjust color saturation of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        saturation_factor (float):  How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image or Tensor: Saturation adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_saturation)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_saturation(img, saturation_factor)

    return F_t.adjust_saturation(img, saturation_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, modes "1", "I", "F" and modes with transparency (alpha channel) are not supported.
        hue_factor (float):  How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image or Tensor: Hue adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_hue)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_hue(img, hue_factor)

    return F_t.adjust_hue(img, hue_factor)
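# Illustrative use of the color adjustment helpers above, kept as a comment (a sketch; a factor of
# 1.0, or a hue shift of 0.0, leaves the image unchanged):
#
#     img = torch.rand(3, 64, 64)
#     out = adjust_brightness(img, 1.5)   # 50% brighter
#     out = adjust_contrast(img, 0.5)     # half the contrast
#     out = adjust_saturation(img, 2.0)   # double the saturation
#     out = adjust_hue(img, 0.1)          # shift the hue; the valid range is [-0.5, 0.5]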


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (PIL Image or Tensor): PIL Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, modes with transparency (alpha channel) are not supported.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 makes the shadows darker,
            while gamma smaller than 1 makes dark regions lighter.
        gain (float): The constant multiplier.

    Returns:
        PIL Image or Tensor: Gamma correction adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_gamma)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_gamma(img, gamma, gain)

    return F_t.adjust_gamma(img, gamma, gain)


def _get_inverse_affine_matrix(
    center: List[float],
    angle: float,
    translate: List[float],
    scale: float,
    shear: List[float],
) -> List[float]:
    # Helper method to compute inverse matrix for affine transformation

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                        , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # RSS without scaling
    a = math.cos(rot - sy) / math.cos(sy)
    b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
    c = math.sin(rot - sy) / math.cos(sy)
    d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)

    # Inverted rotation matrix with scale and shear
    # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
    matrix = [d, -b, 0.0, -c, a, 0.0]
    matrix = [x / scale for x in matrix]

    # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
    matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
    matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)

    # Apply center translation: C * RSS^-1 * C^-1 * T^-1
    matrix[2] += cx
    matrix[5] += cy

    return matrix


def rotate(
    img: Tensor,
    angle: float,
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[int]] = None,
    fill: Optional[List[float]] = None,
    resample: Optional[int] = None,
) -> Tensor:
    """Rotate the image by angle.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to be rotated.
        angle (number): rotation angle value in degrees, counter-clockwise.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, makes the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.
        resample (int, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``interpolation``
                instead.

    Returns:
        PIL Image or Tensor: Rotated image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rotate)
    if resample is not None:
        warnings.warn(
            "The parameter 'resample' is deprecated since 0.12 and will be removed in 0.14. "
            "Please use 'interpolation' instead."
        )
        interpolation = _interpolation_modes_from_int(resample)

    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be an InterpolationMode")

    if not isinstance(img, torch.Tensor):
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)

    center_f = [0.0, 0.0]
    if center is not None:
        img_size = get_image_size(img)
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]

    # due to current incoherence of rotation angle direction between affine and rotate implementations
    # we need to set -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
    return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
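# Illustrative use of ``rotate``, kept as a comment (a sketch; ``expand=True`` enlarges the canvas
# to fit the whole rotated image, otherwise the output keeps the input size):
#
#     img = torch.rand(3, 100, 200)
#     out = rotate(img, 45.0)                           # same (H, W); uncovered area filled with 0
#     out = rotate(img, 45.0, expand=True, fill=[0.5])  # larger output, constant fill value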


vfdev's avatar
vfdev committed
1081
def affine(
1082
1083
1084
1085
1086
1087
1088
1089
1090
    img: Tensor,
    angle: float,
    translate: List[int],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: Optional[List[float]] = None,
    resample: Optional[int] = None,
    fillcolor: Optional[List[float]] = None,
1091
    center: Optional[List[int]] = None,
vfdev's avatar
vfdev committed
1092
1093
) -> Tensor:
    """Apply affine transformation on the image keeping image center invariant.
1094
    If the image is torch Tensor, it is expected
vfdev's avatar
vfdev committed
1095
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
1096
1097

    Args:
vfdev's avatar
vfdev committed
1098
        img (PIL Image or Tensor): image to transform.
1099
1100
        angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
        translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
1101
        scale (float): overall scale
1102
1103
        shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
            If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while
vfdev's avatar
vfdev committed
1104
            the second value corresponds to a shear parallel to the y axis.
1105
1106
1107
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
1108
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
1109
1110
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
1111
1112
1113
1114

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.
1115
1116
1117
1118
1119
1120
1121
        fillcolor (sequence or number, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``fill`` instead.
        resample (int, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``interpolation``
                instead.
1122
1123
        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.
vfdev's avatar
vfdev committed
1124
1125
1126

    Returns:
        PIL Image or Tensor: Transformed image.
1127
    """
1128
1129
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(affine)
1130
1131
    if resample is not None:
        warnings.warn(
1132
1133
            "The parameter 'resample' is deprecated since 0.12 and will be removed in 0.14. "
            "Please use 'interpolation' instead."
1134
1135
1136
1137
1138
1139
        )
        interpolation = _interpolation_modes_from_int(resample)

    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
1140
1141
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
1142
1143
1144
1145
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if fillcolor is not None:
1146
1147
1148
1149
        warnings.warn(
            "The parameter 'fillcolor' is deprecated since 0.12 and will be removed in 0.14. "
            "Please use 'fill' instead."
        )
1150
1151
        fill = fillcolor

vfdev's avatar
vfdev committed
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be an InterpolationMode")

    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    img_size = get_image_size(img)
    if not isinstance(img, torch.Tensor):
        # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
        # it is visually better to estimate the center without 0.5 offset
        # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
        if center is None:
            center = [img_size[0] * 0.5, img_size[1] * 0.5]
        matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)

    center_f = [0.0, 0.0]
    if center is not None:
        img_size = get_image_size(img)
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]

    translate_f = [1.0 * t for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
    return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)

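# Illustrative usage sketch (editor's addition, not part of the library): the tensor shape and
# transform parameters below are arbitrary assumptions.
#     img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
#     out = affine(img, angle=15.0, translate=[2, 3], scale=1.1, shear=[5.0, 0.0])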

@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
    """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
    This transform does not support torch Tensor.

    Args:
        img (PIL Image): PIL Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.

    Returns:
        PIL Image: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_grayscale)
    if isinstance(img, Image.Image):
        return F_pil.to_grayscale(img, num_output_channels)

    raise TypeError("Input should be PIL Image")

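# Illustrative usage sketch (editor's addition): to_grayscale only accepts PIL Images;
# the image created below is an arbitrary assumption.
#     pil_img = Image.new("RGB", (32, 32), color=(128, 64, 32))
#     gray = to_grayscale(pil_img, num_output_channels=1)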

def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert RGB image to grayscale version of image.
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Note:
        Please note that this method supports only RGB images as input. For inputs in other color spaces,
        please consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.

    Args:
        img (PIL Image or Tensor): RGB Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.

    Returns:
        PIL Image or Tensor: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
1255
1256
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rgb_to_grayscale)
    if not isinstance(img, torch.Tensor):
        return F_pil.to_grayscale(img, num_output_channels)

    return F_t.rgb_to_grayscale(img, num_output_channels)

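# Illustrative usage sketch (editor's addition): rgb_to_grayscale also accepts tensors;
# the random input below is an arbitrary assumption.
#     img = torch.rand(3, 32, 32)
#     gray3 = rgb_to_grayscale(img, num_output_channels=3)  # r == g == b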

def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Erase the input Tensor Image with given value.
    This transform does not support PIL Image.

    Args:
        img (Tensor Image): Tensor image of size (C, H, W) to be erased
        i (int): i in (i,j), i.e. the vertical (row) coordinate of the upper left corner.
        j (int): j in (i,j), i.e. the horizontal (column) coordinate of the upper left corner.
        h (int): Height of the erased region.
        w (int): Width of the erased region.
        v: Erasing value.
        inplace (bool, optional): If ``True``, erase the image in place. Default is ``False``.

    Returns:
        Tensor Image: Erased image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(erase)
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor Image. Got {type(img)}")

    if not inplace:
        img = img.clone()

    img[..., i : i + h, j : j + w] = v
    return img

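# Illustrative usage sketch (editor's addition): erases an 8x8 patch starting at row 4, column 6
# with a scalar value; shapes and values are arbitrary assumptions.
#     img = torch.rand(3, 32, 32)
#     erased = erase(img, i=4, j=6, h=8, w=8, v=torch.tensor(0.0))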

def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
    """Performs Gaussian blurring on the image by the given kernel.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be blurred
        kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
            like ``(kx, ky)`` or a single integer for square kernels.

            .. note::
                In torchscript mode kernel_size as single int is not supported, use a sequence of
                length 1: ``[ksize, ]``.
        sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
            sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
            same sigma in both X/Y directions. If None, then it is computed using
            ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
            Default is None.

            .. note::
                In torchscript mode sigma as single float is
                not supported, use a sequence of length 1: ``[sigma, ]``.

    Returns:
        PIL Image or Tensor: Gaussian Blurred version of the image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(gaussian_blur)
    if not isinstance(kernel_size, (int, list, tuple)):
        raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for ksize in kernel_size:
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

    if sigma is None:
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

    if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
    if isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
    for s in sigma:
        if s <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")

    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")

        t_img = to_tensor(img)

    output = F_t.gaussian_blur(t_img, kernel_size, sigma)

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output)
    return output

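# Illustrative usage sketch (editor's addition): a 5x5 kernel with sigma 1.0 in both directions;
# the input tensor and parameter values are arbitrary assumptions.
#     img = torch.rand(3, 32, 32)
#     blurred = gaussian_blur(img, kernel_size=[5, 5], sigma=[1.0, 1.0])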

def invert(img: Tensor) -> Tensor:
    """Invert the colors of an RGB/grayscale image.

    Args:
        img (PIL Image or Tensor): Image to have its colors inverted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: Color inverted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(invert)
    if not isinstance(img, torch.Tensor):
        return F_pil.invert(img)

    return F_t.invert(img)

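# Illustrative usage sketch (editor's addition): for uint8 tensors the result equals 255 - img;
# the input below is an arbitrary assumption.
#     img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
#     inverted = invert(img)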

def posterize(img: Tensor, bits: int) -> Tensor:
    """Posterize an image by reducing the number of bits for each color channel.

    Args:
        img (PIL Image or Tensor): Image to have its colors posterized.
            If img is torch Tensor, it should be of type torch.uint8 and
            it is expected to be in [..., 1 or 3, H, W] format, where ... means
            it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".
        bits (int): The number of bits to keep for each channel (0-8).
    Returns:
        PIL Image or Tensor: Posterized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(posterize)
    if not (0 <= bits <= 8):
        raise ValueError(f"The number of bits should be between 0 and 8. Got {bits}")

    if not isinstance(img, torch.Tensor):
        return F_pil.posterize(img, bits)

    return F_t.posterize(img, bits)

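# Illustrative usage sketch (editor's addition): keeps only the 3 most significant bits per channel;
# the uint8 input below is an arbitrary assumption.
#     img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
#     poster = posterize(img, bits=3)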

def solarize(img: Tensor, threshold: float) -> Tensor:
    """Solarize an RGB/grayscale image by inverting all pixel values above a threshold.

    Args:
        img (PIL Image or Tensor): Image to be solarized.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".
        threshold (float): All pixels equal to or above this value are inverted.
    Returns:
        PIL Image or Tensor: Solarized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(solarize)
    if not isinstance(img, torch.Tensor):
        return F_pil.solarize(img, threshold)

    return F_t.solarize(img, threshold)

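# Illustrative usage sketch (editor's addition): pixels >= 128 are inverted;
# the uint8 input below is an arbitrary assumption.
#     img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
#     solarized = solarize(img, threshold=128)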

def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust the sharpness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        sharpness_factor (float): How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.

    Returns:
        PIL Image or Tensor: Sharpness adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_sharpness)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_sharpness(img, sharpness_factor)

    return F_t.adjust_sharpness(img, sharpness_factor)

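# Illustrative usage sketch (editor's addition): a factor of 2.0 sharpens, 0.0 blurs;
# the input tensor below is an arbitrary assumption.
#     img = torch.rand(3, 32, 32)
#     sharper = adjust_sharpness(img, sharpness_factor=2.0)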

def autocontrast(img: Tensor) -> Tensor:
    """Maximize contrast of an image by remapping its
    pixels per channel so that the lowest becomes black and the lightest
    becomes white.

    Args:
        img (PIL Image or Tensor): Image on which autocontrast is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was autocontrasted.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(autocontrast)
    if not isinstance(img, torch.Tensor):
        return F_pil.autocontrast(img)

    return F_t.autocontrast(img)

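# Illustrative usage sketch (editor's addition): stretches each channel to the full value range;
# the uint8 input below is an arbitrary assumption.
#     img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
#     stretched = autocontrast(img)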

def equalize(img: Tensor) -> Tensor:
    """Equalize the histogram of an image by applying
    a non-linear mapping to the input in order to create a uniform
    distribution of grayscale values in the output.

    Args:
        img (PIL Image or Tensor): Image on which equalize is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            The tensor dtype must be ``torch.uint8`` and values are expected to be in ``[0, 255]``.
            If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was equalized.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(equalize)
    if not isinstance(img, torch.Tensor):
        return F_pil.equalize(img)

    return F_t.equalize(img)
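
# Illustrative usage sketch (editor's addition): tensor input must be torch.uint8;
# the random image below is an arbitrary assumption.
#     img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
#     equalized = equalize(img)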