import math
import numbers
import warnings
from enum import Enum
from typing import List, Tuple, Any, Optional

import numpy as np
import torch
from PIL import Image
from torch import Tensor

try:
    import accimage
except ImportError:
    accimage = None

from ..utils import _log_api_usage_once
from . import functional_pil as F_pil
from . import functional_tensor as F_t


class InterpolationMode(Enum):
    """Interpolation modes
    Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``.
    """

    NEAREST = "nearest"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    # For PIL compatibility
    BOX = "box"
    HAMMING = "hamming"
    LANCZOS = "lanczos"


# TODO: Once torchscript supports Enums with staticmethod
# this can be put into InterpolationMode as staticmethod
def _interpolation_modes_from_int(i: int) -> InterpolationMode:
    inverse_modes_mapping = {
        0: InterpolationMode.NEAREST,
        2: InterpolationMode.BILINEAR,
        3: InterpolationMode.BICUBIC,
        4: InterpolationMode.BOX,
        5: InterpolationMode.HAMMING,
        1: InterpolationMode.LANCZOS,
    }
    return inverse_modes_mapping[i]


pil_modes_mapping = {
    InterpolationMode.NEAREST: 0,
    InterpolationMode.BILINEAR: 2,
    InterpolationMode.BICUBIC: 3,
    InterpolationMode.BOX: 4,
    InterpolationMode.HAMMING: 5,
    InterpolationMode.LANCZOS: 1,
}

_is_pil_image = F_pil._is_pil_image
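
# Illustrative note (not part of the original file): the integer codes above mirror
# PIL's resampling constants (e.g. PIL.Image.NEAREST == 0, PIL.Image.BILINEAR == 2),
# so _interpolation_modes_from_int(2) returns InterpolationMode.BILINEAR and
# pil_modes_mapping[InterpolationMode.BILINEAR] == 2.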


def get_image_size(img: Tensor) -> List[int]:
    """Returns the size of an image as [width, height].

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image size.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_image_size)
    if isinstance(img, torch.Tensor):
        return F_t.get_image_size(img)

    return F_pil.get_image_size(img)


def get_image_num_channels(img: Tensor) -> int:
    """Returns the number of channels of an image.

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        int: The number of channels.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(get_image_num_channels)
    if isinstance(img, torch.Tensor):
        return F_t.get_image_num_channels(img)

    return F_pil.get_image_num_channels(img)
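
# Usage sketch (illustrative, not part of the original file): both helpers accept a
# PIL Image or a Tensor laid out as [..., C, H, W].
#
#     img = torch.zeros(3, 480, 640)        # C x H x W
#     get_image_size(img)                   # -> [640, 480]  (width, height)
#     get_image_num_channels(img)           # -> 3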


@torch.jit.unused
def _is_numpy(img: Any) -> bool:
    return isinstance(img, np.ndarray)


@torch.jit.unused
def _is_numpy_image(img: Any) -> bool:
    return img.ndim in {2, 3}


def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.ToTensor` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_tensor)
    if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

    default_float_dtype = torch.get_default_dtype()

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]

        img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.to(dtype=default_float_dtype).div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic).to(dtype=default_float_dtype)

    # handle PIL Image
    mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}
    img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))

    if pic.mode == "1":
        img = 255 * img
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.to(dtype=default_float_dtype).div(255)
    else:
        return img
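
# Usage sketch (illustrative): a uint8 H x W x C array becomes a C x H x W float
# tensor with values scaled from [0, 255] to [0.0, 1.0].
#
#     arr = np.zeros((480, 640, 3), dtype=np.uint8)
#     t = to_tensor(arr)
#     t.shape, t.dtype                      # -> (torch.Size([3, 480, 640]), torch.float32)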


def pil_to_tensor(pic):
    """Convert a ``PIL Image`` to a tensor of the same type.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.PILToTensor` for more details.

    .. note::

        A deep copy of the underlying array is performed.

    Args:
        pic (PIL Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pil_to_tensor)
    if not F_pil._is_pil_image(pic):
        raise TypeError(f"pic should be PIL Image. Got {type(pic)}")

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage format is always uint8 internally, so always return uint8 here
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
        pic.copyto(nppic)
        return torch.as_tensor(nppic)

    # handle PIL Image
    img = torch.as_tensor(np.array(pic, copy=True))
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1))
    return img
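
# Illustrative note: unlike ``to_tensor``, ``pil_to_tensor`` keeps the original dtype
# and does not rescale, e.g. an 8-bit RGB PIL Image becomes a uint8 tensor of shape
# [3, H, W] with values still in [0, 255].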


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly
    This function does not support PIL Image.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        Tensor: Converted image

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(convert_image_dtype)
    if not isinstance(image, torch.Tensor):
        raise TypeError("Input img should be Tensor Image")

    return F_t.convert_image_dtype(image, dtype)
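
# Usage sketch (illustrative): converting dtype also rescales the value range.
#
#     u8 = torch.tensor([0, 128, 255], dtype=torch.uint8)
#     f = convert_image_dtype(u8, torch.float32)    # values in [0.0, 1.0]
#     convert_image_dtype(f, torch.uint8)           # back to (approximately) the original values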


def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_pil_image)
    if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
        raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")

    elif isinstance(pic, torch.Tensor):
        if pic.ndimension() not in {2, 3}:
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.")

        elif pic.ndimension() == 2:
            # if 2D image, add channel dimension (CHW)
            pic = pic.unsqueeze(0)

        # check number of channels
        if pic.shape[-3] > 4:
            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.")

    elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
            raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

        elif pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)

        # check number of channels
        if pic.shape[-1] > 4:
            raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")

    npimg = pic
    if isinstance(pic, torch.Tensor):
        if pic.is_floating_point() and mode != "F":
            pic = pic.mul(255).byte()
        npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        raise TypeError(f"Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}")

    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = "L"
        elif npimg.dtype == np.int16:
            expected_mode = "I;16"
        elif npimg.dtype == np.int32:
            expected_mode = "I"
        elif npimg.dtype == np.float32:
            expected_mode = "F"
        if mode is not None and mode != expected_mode:
            raise ValueError(f"Incorrect mode ({mode}) supplied for input type {npimg.dtype}. Should be {expected_mode}")
        mode = expected_mode

    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ["LA"]
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "LA"

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "RGBA"
    else:
        permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs")
        if mode is None and npimg.dtype == np.uint8:
            mode = "RGB"

    if mode is None:
        raise TypeError(f"Input type {npimg.dtype} is not supported")

    return Image.fromarray(npimg, mode=mode)
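
# Usage sketch (illustrative): a float CHW tensor in [0, 1] round-trips through PIL.
#
#     t = torch.rand(3, 64, 64)
#     pil_img = to_pil_image(t)             # RGB image, values scaled to [0, 255]
#     back = to_tensor(pil_img)             # float tensor in [0, 1] again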


def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Normalize a float tensor image with mean and standard deviation.
    This transform does not support PIL Image.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(normalize)
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Input tensor should be a torch tensor. Got {type(tensor)}.")

    if not tensor.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")

    if tensor.ndim < 3:
        raise ValueError(
            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
        )

    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std == 0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    tensor.sub_(mean).div_(std)
    return tensor
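
# Usage sketch (illustrative): per-channel standardization with the commonly used
# ImageNet statistics.
#
#     img = torch.rand(3, 224, 224)
#     out = normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])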


def resize(
    img: Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = None,
) -> Tensor:
    r"""Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
        types. See also below the ``antialias`` parameter, which can help make the output of PIL images and tensors
        closer.

    Args:
        img (PIL Image or Tensor): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e, if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e the
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
            ``InterpolationMode.BILINEAR`` only mode. This can help make the output for PIL images and tensors
            closer.

            .. warning::
                There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.

    Returns:
        PIL Image or Tensor: Resized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resize)
    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if not isinstance(img, torch.Tensor):
        if antialias is not None and not antialias:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)

    return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)
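
# Usage sketch (illustrative): resizing a CHW tensor; an int size keeps the aspect
# ratio (smaller edge matched), a (h, w) pair forces the exact size.
#
#     img = torch.rand(3, 480, 640)
#     resize(img, [240, 320]).shape         # -> torch.Size([3, 240, 320])
#     resize(img, 240).shape                # -> torch.Size([3, 240, 320]) as well (smaller edge = 240)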


def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Args:
        img (PIL Image or Tensor): Image to be padded.
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or str or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If the input is a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        PIL Image or Tensor: Padded image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pad)
    if not isinstance(img, torch.Tensor):
        return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)

    return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
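
# Usage sketch (illustrative): asymmetric constant padding of a CHW tensor.
#
#     img = torch.rand(3, 10, 10)
#     pad(img, [1, 2, 3, 4]).shape          # left=1, top=2, right=3, bottom=4 -> torch.Size([3, 16, 14])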


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop the given image at specified location and output size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(crop)
    if not isinstance(img, torch.Tensor):
        return F_pil.crop(img, top, left, height, width)

    return F_t.crop(img, top, left, height, width)

def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
            it is used for both directions.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(center_crop)
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    image_width, image_height = get_image_size(img)
    crop_height, crop_width = output_size

    if crop_width > image_width or crop_height > image_height:
        padding_ltrb = [
            (crop_width - image_width) // 2 if crop_width > image_width else 0,
            (crop_height - image_height) // 2 if crop_height > image_height else 0,
            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
        ]
        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
        image_width, image_height = get_image_size(img)
        if crop_width == image_width and crop_height == image_height:
            return img

    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop(img, crop_top, crop_left, crop_height, crop_width)
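
# Usage sketch (illustrative): crop() and center_crop() compose with the other
# helpers; center_crop pads with zeros first if the crop is larger than the input.
#
#     img = torch.rand(3, 100, 150)
#     crop(img, 10, 20, 50, 60).shape       # -> torch.Size([3, 50, 60])
#     center_crop(img, [64, 64]).shape      # -> torch.Size([3, 64, 64])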


def resized_crop(
    img: Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> Tensor:
    """Crop the given image and resize it to desired size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resized_crop)
    img = crop(img, top, left, height, width)
    img = resize(img, size, interpolation)
    return img


def hflip(img: Tensor) -> Tensor:
    """Horizontally flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor:  Horizontally flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(hflip)
    if not isinstance(img, torch.Tensor):
        return F_pil.hflip(img)

    return F_t.hflip(img)


def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
    """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.

    In Perspective Transform each pixel (x, y) in the original image gets transformed as,
     (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )

    Args:
        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.

    Returns:
        octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
    """
    a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)

    for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
        a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
        a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

    b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
    res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution

    output: List[float] = res.tolist()
    return output


def perspective(
    img: Tensor,
    startpoints: List[List[int]],
    endpoints: List[List[int]],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Perform perspective transform of the given image.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be transformed.
        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.

    Returns:
        PIL Image or Tensor: transformed Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(perspective)

    coeffs = _get_perspective_coeffs(startpoints, endpoints)

    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if not isinstance(img, torch.Tensor):
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.perspective(img, coeffs, interpolation=pil_interpolation, fill=fill)

    return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill)
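
# Usage sketch (illustrative): warping the four image corners inward; the helper
# _get_perspective_coeffs solves the 8 coefficients from the two corner lists.
#
#     img = torch.rand(3, 100, 100)
#     start = [[0, 0], [99, 0], [99, 99], [0, 99]]
#     end = [[10, 5], [90, 10], [95, 90], [5, 95]]
#     out = perspective(img, start, end)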


def vflip(img: Tensor) -> Tensor:
    """Vertically flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor:  Vertically flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(vflip)
    if not isinstance(img, torch.Tensor):
        return F_pil.vflip(img)

    return F_t.vflip(img)


def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Crop the given image into four corners and the central crop.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Returns:
       tuple: tuple (tl, tr, bl, br, center)
       Corresponding top left, top right, bottom left, bottom right and center crop.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(five_crop)
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    image_width, image_height = get_image_size(img)
    crop_height, crop_width = size
    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    tl = crop(img, 0, 0, crop_height, crop_width)
    tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
    br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)

    center = center_crop(img, [crop_height, crop_width])

    return tl, tr, bl, br, center


def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]:
    """Generate ten cropped images from the given image.
    Crop the given image into four corners and the central crop plus the
    flipped version of these (horizontal flipping is used by default).
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
        Corresponding top left, top right, bottom left, bottom right and
        center crop and same for the flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(ten_crop)
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)
    return first_five + second_five
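
# Usage sketch (illustrative): ten_crop returns 10 crops (the five_crop outputs plus
# their flipped counterparts), often stacked for test-time augmentation.
#
#     img = torch.rand(3, 256, 256)
#     crops = ten_crop(img, [224, 224])
#     torch.stack(crops).shape              # -> torch.Size([10, 3, 224, 224])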


def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Adjust brightness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        brightness_factor (float):  How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image or Tensor: Brightness adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_brightness)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_brightness(img, brightness_factor)

    return F_t.adjust_brightness(img, brightness_factor)


def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust contrast of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image or Tensor: Contrast adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_contrast)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_contrast(img, contrast_factor)

    return F_t.adjust_contrast(img, contrast_factor)


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Adjust color saturation of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        saturation_factor (float):  How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image or Tensor: Saturation adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_saturation)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_saturation(img, saturation_factor)

    return F_t.adjust_saturation(img, saturation_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
        hue_factor (float):  How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image or Tensor: Hue adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_hue)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_hue(img, hue_factor)

    return F_t.adjust_hue(img, hue_factor)
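
# Usage sketch (illustrative): the adjust_* functions compose naturally, e.g. a
# simple hand-rolled "color jitter" on a float RGB tensor in [0, 1].
#
#     img = torch.rand(3, 64, 64)
#     out = adjust_brightness(img, 1.2)
#     out = adjust_contrast(out, 0.8)
#     out = adjust_saturation(out, 1.5)
#     out = adjust_hue(out, 0.05)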


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (PIL Image or Tensor): PIL Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, modes with transparency (alpha channel) are not supported.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 make the shadows darker,
            while gamma smaller than 1 make dark regions lighter.
        gain (float): The constant multiplier.
    Returns:
        PIL Image or Tensor: Gamma correction adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_gamma)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_gamma(img, gamma, gain)

    return F_t.adjust_gamma(img, gamma, gain)


def _get_inverse_affine_matrix(
    center: List[float],
    angle: float,
    translate: List[float],
    scale: float,
    shear: List[float],
) -> List[float]:
    # Helper method to compute inverse matrix for affine transformation

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                      , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # RSS without scaling
    a = math.cos(rot - sy) / math.cos(sy)
    b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
    c = math.sin(rot - sy) / math.cos(sy)
    d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)

    # Inverted rotation matrix with scale and shear
    # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
    matrix = [d, -b, 0.0, -c, a, 0.0]
    matrix = [x / scale for x in matrix]

    # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
    matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
    matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)

    # Apply center translation: C * RSS^-1 * C^-1 * T^-1
    matrix[2] += cx
    matrix[5] += cy

    return matrix
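
# Sanity-check sketch (illustrative): with zero rotation, translation and shear and
# scale 1, the inverse matrix is the identity in 2x3 row-major form.
#
#     _get_inverse_affine_matrix([0.0, 0.0], 0.0, [0.0, 0.0], 1.0, [0.0, 0.0])
#     # -> approximately [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]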


def rotate(
    img: Tensor,
    angle: float,
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[int]] = None,
    fill: Optional[List[float]] = None,
    resample: Optional[int] = None,
) -> Tensor:
    """Rotate the image by angle.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to be rotated.
        angle (number): rotation angle value in degrees, counter-clockwise.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.

    Returns:
        PIL Image or Tensor: Rotated image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rotate)
    if resample is not None:
        warnings.warn(
            "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
        )
        interpolation = _interpolation_modes_from_int(resample)

    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if not isinstance(img, torch.Tensor):
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)

    center_f = [0.0, 0.0]
    if center is not None:
        img_size = get_image_size(img)
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]

    # due to current incoherence of rotation angle direction between affine and rotate implementations
    # we need to set -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
    return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
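
# Usage sketch (illustrative): rotating a tensor image; ``expand=True`` grows the
# canvas so the rotated content is not clipped.
#
#     img = torch.rand(3, 100, 200)
#     rotate(img, 45.0).shape               # -> torch.Size([3, 100, 200])
#     rotate(img, 45.0, expand=True).shape  # larger spatial size than the input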


vfdev's avatar
vfdev committed
1076
def affine(
1077
1078
1079
1080
1081
1082
1083
1084
1085
    img: Tensor,
    angle: float,
    translate: List[int],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: Optional[List[float]] = None,
    resample: Optional[int] = None,
    fillcolor: Optional[List[float]] = None,
1086
    center: Optional[List[int]] = None,
vfdev's avatar
vfdev committed
1087
1088
) -> Tensor:
    """Apply affine transformation on the image keeping image center invariant.
1089
    If the image is torch Tensor, it is expected
vfdev's avatar
vfdev committed
1090
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
1091
1092

    Args:
vfdev's avatar
vfdev committed
1093
        img (PIL Image or Tensor): image to transform.
1094
1095
        angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
        translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
1096
        scale (float): overall scale
1097
1098
        shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
            If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while
vfdev's avatar
vfdev committed
1099
            the second value corresponds to a shear parallel to the y axis.
1100
1101
1102
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
1103
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
1104
1105
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
1106
1107
1108
1109

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.
1110
        fillcolor (sequence, int, float): deprecated argument and will be removed since v0.10.0.
1111
            Please use the ``fill`` parameter instead.
1112
        resample (int, optional): deprecated argument and will be removed since v0.10.0.
1113
            Please use the ``interpolation`` parameter instead.
1114
1115
        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.
vfdev's avatar
vfdev committed
1116
1117
1118

    Returns:
        PIL Image or Tensor: Transformed image.
1119
    """
1120
1121
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(affine)
1122
1123
1124
1125
1126
1127
1128
1129
1130
    if resample is not None:
        warnings.warn(
            "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
        )
        interpolation = _interpolation_modes_from_int(resample)

    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
1131
1132
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
1133
1134
1135
1136
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if fillcolor is not None:
1137
        warnings.warn("Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead")
1138
1139
        fill = fillcolor

vfdev's avatar
vfdev committed
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

1155
1156
    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")
1157

vfdev's avatar
vfdev committed
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    img_size = get_image_size(img)
    if not isinstance(img, torch.Tensor):
        # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
        # it is visually better to estimate the center without 0.5 offset
        # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
        if center is None:
            center = [img_size[0] * 0.5, img_size[1] * 0.5]
        matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)

    center_f = [0.0, 0.0]
    if center is not None:
        img_size = get_image_size(img)
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]

    translate_f = [1.0 * t for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
    return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)


@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
    """Convert PIL image of any mode (RGB, HSV, LAB, etc.) to a grayscale version of the image.
    This transform does not support torch Tensor.

    Args:
        img (PIL Image): PIL Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.

    Returns:
        PIL Image: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
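
    Example:
        A minimal usage sketch (illustrative only; the blank RGB image below is an
        arbitrary stand-in input)::

            >>> from PIL import Image
            >>> from torchvision.transforms.functional import to_grayscale
            >>> img = Image.new("RGB", (32, 32))
            >>> gray = to_grayscale(img, num_output_channels=1)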
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_grayscale)
    if isinstance(img, Image.Image):
        return F_pil.to_grayscale(img, num_output_channels)

    raise TypeError("Input should be PIL Image")


def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert RGB image to grayscale version of image.
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions

    Note:
        Please note that this method supports only RGB images as input. For inputs in other color spaces,
        please consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.

    Args:
        img (PIL Image or Tensor): RGB Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.

    Returns:
        PIL Image or Tensor: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
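
    Example:
        A minimal usage sketch (illustrative only; the random tensor below is an
        arbitrary stand-in input)::

            >>> import torch
            >>> from torchvision.transforms.functional import rgb_to_grayscale
            >>> img = torch.rand(3, 32, 32)
            >>> gray = rgb_to_grayscale(img, num_output_channels=1)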
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rgb_to_grayscale)
    if not isinstance(img, torch.Tensor):
        return F_pil.to_grayscale(img, num_output_channels)

    return F_t.rgb_to_grayscale(img, num_output_channels)


def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Erase the input Tensor Image with given value.
    This transform does not support PIL Image.

    Args:
        img (Tensor Image): Tensor image of size (C, H, W) to be erased
        i (int): i in (i,j) i.e. coordinates of the upper left corner.
        j (int): j in (i,j) i.e. coordinates of the upper left corner.
        h (int): Height of the erased region.
        w (int): Width of the erased region.
        v: Erasing value.
        inplace(bool, optional): For in-place operations. By default, it is set to False.

    Returns:
        Tensor Image: Erased image.
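
    Example:
        A minimal usage sketch (illustrative only; the random tensor and region
        below are arbitrary stand-ins)::

            >>> import torch
            >>> from torchvision.transforms.functional import erase
            >>> img = torch.rand(3, 32, 32)
            >>> out = erase(img, i=4, j=4, h=8, w=8, v=torch.tensor(0.0))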
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(erase)
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor Image. Got {type(img)}")

    if not inplace:
        img = img.clone()

    img[..., i : i + h, j : j + w] = v
    return img


def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
    """Performs Gaussian blurring on the image by given kernel.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be blurred
        kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
            like ``(kx, ky)`` or a single integer for square kernels.

            .. note::
                In torchscript mode kernel_size as single int is not supported, use a sequence of
                length 1: ``[ksize, ]``.
        sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
            sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
            same sigma in both X/Y directions. If None, then it is computed using
            ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
            Default, None.

            .. note::
                In torchscript mode sigma as single float is
                not supported, use a sequence of length 1: ``[sigma, ]``.

    Returns:
        PIL Image or Tensor: Gaussian Blurred version of the image.
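
    Example:
        A minimal usage sketch (illustrative only; the kernel size and sigma below
        are arbitrary choices, not recommended defaults)::

            >>> import torch
            >>> from torchvision.transforms.functional import gaussian_blur
            >>> img = torch.rand(3, 32, 32)
            >>> out = gaussian_blur(img, kernel_size=[5, 5], sigma=[1.0, 1.0])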
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(gaussian_blur)
    if not isinstance(kernel_size, (int, list, tuple)):
        raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for ksize in kernel_size:
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

    if sigma is None:
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

    if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
    if isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
    for s in sigma:
        if s <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")

    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")

        t_img = to_tensor(img)

    output = F_t.gaussian_blur(t_img, kernel_size, sigma)

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output)
    return output


def invert(img: Tensor) -> Tensor:
    """Invert the colors of an RGB/grayscale image.

    Args:
        img (PIL Image or Tensor): Image to have its colors inverted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: Color inverted image.
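
    Example:
        A minimal usage sketch (illustrative only; the random tensor below is an
        arbitrary stand-in input)::

            >>> import torch
            >>> from torchvision.transforms.functional import invert
            >>> img = torch.rand(3, 32, 32)
            >>> out = invert(img)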
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(invert)
    if not isinstance(img, torch.Tensor):
        return F_pil.invert(img)

    return F_t.invert(img)


def posterize(img: Tensor, bits: int) -> Tensor:
    """Posterize an image by reducing the number of bits for each color channel.

    Args:
        img (PIL Image or Tensor): Image to have its colors posterized.
            If img is torch Tensor, it should be of type torch.uint8 and
            it is expected to be in [..., 1 or 3, H, W] format, where ... means
            it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".
        bits (int): The number of bits to keep for each channel (0-8).
    Returns:
        PIL Image or Tensor: Posterized image.
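
    Example:
        A minimal usage sketch (illustrative only; the random ``uint8`` tensor below
        is an arbitrary stand-in input)::

            >>> import torch
            >>> from torchvision.transforms.functional import posterize
            >>> img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
            >>> out = posterize(img, bits=4)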
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(posterize)
    if not (0 <= bits <= 8):
        raise ValueError(f"The number of bits should be between 0 and 8. Got {bits}")

    if not isinstance(img, torch.Tensor):
        return F_pil.posterize(img, bits)

    return F_t.posterize(img, bits)


def solarize(img: Tensor, threshold: float) -> Tensor:
    """Solarize an RGB/grayscale image by inverting all pixel values above a threshold.

    Args:
        img (PIL Image or Tensor): Image to have its colors inverted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".
        threshold (float): All pixels equal or above this value are inverted.
    Returns:
        PIL Image or Tensor: Solarized image.
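
    Example:
        A minimal usage sketch (illustrative only; the random ``uint8`` tensor and
        threshold below are arbitrary stand-ins)::

            >>> import torch
            >>> from torchvision.transforms.functional import solarize
            >>> img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
            >>> out = solarize(img, threshold=128)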
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(solarize)
    if not isinstance(img, torch.Tensor):
        return F_pil.solarize(img, threshold)

    return F_t.solarize(img, threshold)


def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust the sharpness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        sharpness_factor (float):  How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.

    Returns:
        PIL Image or Tensor: Sharpness adjusted image.
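
    Example:
        A minimal usage sketch (illustrative only; the random tensor and factor
        below are arbitrary stand-ins)::

            >>> import torch
            >>> from torchvision.transforms.functional import adjust_sharpness
            >>> img = torch.rand(3, 32, 32)
            >>> out = adjust_sharpness(img, sharpness_factor=2.0)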
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_sharpness)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_sharpness(img, sharpness_factor)

    return F_t.adjust_sharpness(img, sharpness_factor)


def autocontrast(img: Tensor) -> Tensor:
    """Maximize contrast of an image by remapping its
    pixels per channel so that the lowest becomes black and the lightest
    becomes white.

    Args:
        img (PIL Image or Tensor): Image on which autocontrast is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was autocontrasted.
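
    Example:
        A minimal usage sketch (illustrative only; the random tensor below is an
        arbitrary stand-in input)::

            >>> import torch
            >>> from torchvision.transforms.functional import autocontrast
            >>> img = torch.rand(3, 32, 32)
            >>> out = autocontrast(img)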
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(autocontrast)
    if not isinstance(img, torch.Tensor):
        return F_pil.autocontrast(img)

    return F_t.autocontrast(img)


def equalize(img: Tensor) -> Tensor:
    """Equalize the histogram of an image by applying
    a non-linear mapping to the input in order to create a uniform
    distribution of grayscale values in the output.

    Args:
        img (PIL Image or Tensor): Image on which equalize is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            The tensor dtype must be ``torch.uint8`` and values are expected to be in ``[0, 255]``.
            If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was equalized.
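
    Example:
        A minimal usage sketch (illustrative only; the random ``uint8`` tensor below
        is an arbitrary stand-in input)::

            >>> import torch
            >>> from torchvision.transforms.functional import equalize
            >>> img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
            >>> out = equalize(img)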
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(equalize)
    if not isinstance(img, torch.Tensor):
        return F_pil.equalize(img)

    return F_t.equalize(img)