functional.py 38.5 KB
Newer Older
1
import math
2
3
import numbers
import warnings
vfdev's avatar
vfdev committed
4
from typing import Any, Optional
5
6

import numpy as np
vfdev's avatar
vfdev committed
7
from PIL import Image
8
9
10

import torch
from torch import Tensor
vfdev's avatar
vfdev committed
11
from torch.jit.annotations import List, Tuple
12

13
14
15
16
17
try:
    import accimage
except ImportError:
    accimage = None

18
19
20
from . import functional_pil as F_pil
from . import functional_tensor as F_t

21

vfdev's avatar
vfdev committed
22
# Module-level aliases of PIL-backend helpers. ``_parse_fill`` is used by
# ``perspective`` in this module; ``_is_pil_image`` is presumably kept for
# callers that import it from here (NOTE(review): confirm before removing).
_is_pil_image = F_pil._is_pil_image
_parse_fill = F_pil._parse_fill


def _get_image_size(img: Tensor) -> List[int]:
    """Return the spatial size of ``img`` as ``[width, height]``.

    Tensors are delegated to the tensor backend; any other input is handed to
    the PIL backend.

    Args:
        img (PIL Image or Tensor): Image whose size is queried.

    Returns:
        List[int]: ``[width, height]`` of the image.
    """
    if isinstance(img, torch.Tensor):
        return F_t._get_image_size(img)

    return F_pil._get_image_size(img)


@torch.jit.unused
def _is_numpy(img: Any) -> bool:
37
38
39
    return isinstance(img, np.ndarray)


vfdev's avatar
vfdev committed
40
41
@torch.jit.unused
def _is_numpy_image(img: Any) -> bool:
42
    return img.ndim in {2, 3}
43
44
45
46
47
48
49
50
51
52
53
54
55


def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    The result is laid out as (C, H, W); 2-D ndarrays get a single channel.
    If the intermediate tensor is a ``torch.ByteTensor`` (e.g. a ``uint8``
    ndarray or an 8-bit PIL mode), it is converted to float and divided by 255.
    Other dtypes are returned without rescaling.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor an ndarray.
        ValueError: If ``pic`` is an ndarray that is not 2- or 3-dimensional.
    """
    if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]

        # HWC -> CHW
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    # ``copy=False`` is deliberately not passed to ``np.array``. Under NumPy 2
    # it means "never copy" and raises ValueError whenever a cast or copy is
    # required (e.g. for mode '1'). The default always returns a fresh,
    # writable array on every NumPy version.
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32))
    elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


def pil_to_tensor(pic):
    """Convert a ``PIL Image`` to a tensor of the same type.

    See ``AsTensor`` for more details.

    Unlike ``to_tensor``, pixel values are not rescaled. The result is laid
    out as (C, H, W). Exception: an ``accimage.Image`` is copied into a
    ``float32`` buffer, so its dtype is not preserved.

    Args:
        pic (PIL Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.

    Raises:
        TypeError: If ``pic`` is not a PIL Image.
    """
    if not F_pil._is_pil_image(pic):
        raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.as_tensor(nppic)

    # handle PIL Image
    # Take an explicit, writable copy. ``np.asarray`` on a PIL Image can return
    # a read-only array. ``torch.as_tensor`` would then share that buffer,
    # warn that it is not writable, and in-place ops on the tensor would be
    # undefined.
    img = torch.as_tensor(np.array(pic, copy=True))
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1))
    return img


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly.

    Integer images are interpreted on ``[0, iinfo(dtype).max]``; floating point
    images on ``[0, 1]``. If ``image`` already has ``dtype``, it is returned
    unchanged (same object).

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        (torch.Tensor): Converted image

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    source_dtype = image.dtype
    if source_dtype == dtype:
        return image

    if not source_dtype.is_floating_point:
        input_max = torch.iinfo(source_dtype).max
        if dtype.is_floating_point:
            # int -> float: map [0, input_max] onto [0, 1]
            return image.to(dtype) / input_max

        # int -> int: rescale by the ratio between the (max + 1) of each range
        output_max = torch.iinfo(dtype).max
        if input_max > output_max:
            shrink = (input_max + 1) // (output_max + 1)
            return (image // shrink).to(dtype)
        grow = (output_max + 1) // (input_max + 1)
        return image.to(dtype) * grow

    # float -> float
    if dtype.is_floating_point:
        return image.to(dtype)

    # float -> int: refuse casts whose float type cannot represent every integer
    unsafe = (source_dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
        source_dtype == torch.float64 and dtype == torch.int64
    )
    if unsafe:
        msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
        raise RuntimeError(msg)

    # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
    # Scaling by ``max`` would only reach ``max`` for an input of exactly 1.0.
    # Scaling by ``max + 1 - eps`` and truncating spreads the [0, 1] range
    # more evenly over the integer values.
    eps = 1e-3
    return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)


def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Tensors are read in (C, H, W) layout and ndarrays in (H, W, C) layout.
    2-D inputs are treated as single-channel images. A ``torch.FloatTensor``
    is multiplied by 255 and cast to ``uint8`` unless ``mode == 'F'``.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.

    Raises:
        TypeError: If ``pic`` is neither a Tensor nor an ndarray, or if no PIL
            mode can be inferred for its dtype and channel count.
        ValueError: If ``pic`` is not 2/3-dimensional, or if ``mode`` does not
            fit the input's channel count or dtype.
    """
    if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    elif isinstance(pic, torch.Tensor):
        if pic.ndimension() not in {2, 3}:
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension()))

        elif pic.ndimension() == 2:
            # if 2D image, add channel dimension (CHW)
            pic = pic.unsqueeze(0)

    elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

        elif pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)

    npimg = pic
    if isinstance(pic, torch.FloatTensor) and mode != 'F':
        pic = pic.mul(255).byte()
    if isinstance(pic, torch.Tensor):
        # CHW tensor -> HWC ndarray
        npimg = np.transpose(pic.numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
                        'not {}'.format(type(npimg)))

    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = 'L'
        elif npimg.dtype == np.int16:
            expected_mode = 'I;16'
        elif npimg.dtype == np.int32:
            expected_mode = 'I'
        elif npimg.dtype == np.float32:
            expected_mode = 'F'
        if mode is not None and mode != expected_mode:
            # Format the array's dtype (not the ``np.dtype`` class itself) so
            # the message names the offending input type.
            raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
                             .format(mode, npimg.dtype, expected_mode))
        mode = expected_mode

    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ['LA']
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'LA'

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'

    if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))

    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std, inplace=False):
    """Normalize a tensor image with mean and standard deviation.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Each channel ``c`` is transformed as ``(input[..., c, :, :] - mean[c]) / std[c]``.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W), or a batch of images of
            size (..., C, H, W), to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.

    Raises:
        TypeError: If ``tensor`` is not a torch tensor.
        ValueError: If ``tensor`` has fewer than 3 dimensions, or if ``std``
            contains a zero after conversion to the tensor's dtype.
    """
    if not torch.is_tensor(tensor):
        raise TypeError('tensor should be a torch tensor. Got {}.'.format(type(tensor)))

    # Leading batch dimensions are accepted: the per-channel statistics below
    # are shaped (C, 1, 1) and broadcast over any dimensions before C.
    if tensor.ndimension() < 3:
        raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
                         '{}.'.format(tensor.size()))

    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    # Checked after the dtype conversion: e.g. an integer tensor truncates a
    # small std to 0.
    if (std == 0).any():
        raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
    if mean.ndim == 1:
        mean = mean[:, None, None]
    if std.ndim == 1:
        std = std[:, None, None]
    tensor.sub_(mean).div_(std)
    return tensor


def resize(img: Tensor, size: List[int], interpolation: int = Image.BILINEAR) -> Tensor:
    r"""Resize the input image to the given size.
    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        img (PIL Image or Tensor): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e, if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
            In torchscript mode size as single int is not supported, use a tuple or
            list of length 1: ``[size, ]``.
        interpolation (int, optional): Desired interpolation enum defined by `filters`_.
            Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
            and ``PIL.Image.BICUBIC`` are supported.

    Returns:
        PIL Image or Tensor: Resized image.
    """
    # Tensors go to the tensor backend; anything else to the PIL backend.
    if not isinstance(img, torch.Tensor):
        return F_pil.resize(img, size=size, interpolation=interpolation)

    return F_t.resize(img, size=size, interpolation=interpolation)


def scale(*args, **kwargs):
    """Deprecated alias of :func:`resize`.

    Emits a warning and forwards all positional and keyword arguments to
    ``resize`` unchanged.
    """
    # stacklevel=2 attributes the warning to the caller's line rather than
    # to this wrapper, so users can find the call site to migrate.
    warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                  "please use transforms.Resize instead.", stacklevel=2)
    return resize(*args, **kwargs)


def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad every border of the given image.

    Works on a PIL Image or on a torch Tensor. A Tensor is expected to have
    [..., H, W] shape, where ... stands for any number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be padded.
        padding (int or tuple or list): Padding on each border. A single int pads
            all four borders by that amount. A tuple of length 2 gives the
            padding for left/right and top/bottom respectively. A tuple of
            length 4 gives the padding for the left, top, right and bottom
            borders respectively.
            In torchscript mode padding as single int is not supported, use a tuple or
            list of length 1: ``[padding, ]``.
        fill (int or str or tuple): Pixel fill value for constant fill. Default is 0.
            A tuple of length 3 fills the R, G, B channels respectively.
            Only used when ``padding_mode`` is constant. Only int values are
            supported for Tensors.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or
            symmetric. Default is constant. Mode symmetric is not yet supported
            for Tensor inputs.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last value on the edge)

                       padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                       will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value on the edge)

                         padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                         will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        PIL Image or Tensor: Padded image.
    """
    # Non-tensor input is handled by the PIL backend, tensors by the tensor backend.
    if not isinstance(img, torch.Tensor):
        return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)

    return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Cut a ``height`` x ``width`` window out of the image, starting at (``top``, ``left``).

    Works on a PIL Image or on a Tensor. A Tensor is expected to have
    [..., H, W] shape, where ... stands for any number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    # Non-tensor input is handled by the PIL backend, tensors by the tensor backend.
    if not isinstance(img, torch.Tensor):
        return F_pil.crop(img, top, left, height, width)

    return F_t.crop(img, top, left, height, width)


def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    """Cut a window of ``output_size`` out of the middle of the image.

    Works on a PIL Image or on a Tensor. A Tensor is expected to have
    [..., H, W] shape, where ... stands for any number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. An int,
            or a sequence holding a single int, is used for both directions.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    img_w, img_h = _get_image_size(img)
    out_h, out_w = output_size

    # Temporary workaround: int(round((img_h - out_h) / 2.)) can give a
    # different result in the Python function than in the scripted one, so
    # the rounding is written as "add one, halve, truncate" instead.
    top = int((img_h - out_h + 1) * 0.5)
    left = int((img_w - out_w + 1) * 0.5)
    return crop(img, top, left, out_h, out_w)


def resized_crop(
        img: Tensor, top: int, left: int, height: int, width: int, size: List[int], interpolation: int = Image.BILINEAR
) -> Tensor:
    """Crop a window out of the image, then resize that window to ``size``.

    Works on a PIL Image or on a Tensor. A Tensor is expected to have
    [..., H, W] shape, where ... stands for any number of leading dimensions.

    Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (int, optional): Desired interpolation enum defined by `filters`_.
            Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
            and ``PIL.Image.BICUBIC`` are supported.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    window = crop(img, top, left, height, width)
    return resize(window, size, interpolation)


def hflip(img: Tensor) -> Tensor:
    """Horizontally flip the given PIL Image or Tensor.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor:  Horizontally flipped image.
    """
    # Non-tensor input is handled by the PIL backend, tensors by the tensor backend.
    if not isinstance(img, torch.Tensor):
        return F_pil.hflip(img)

    return F_t.hflip(img)


def _get_perspective_coeffs(startpoints, endpoints):
    """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.

Vitaliy Chiley's avatar
Vitaliy Chiley committed
497
    In Perspective Transform each pixel (x, y) in the original image gets transformed as,
498
499
500
     (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )

    Args:
Vitaliy Chiley's avatar
Vitaliy Chiley committed
501
        List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
vfdev's avatar
vfdev committed
502
        List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image
503
504
505
506
507
508
509
510
511
512
513
    Returns:
        octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
    """
    matrix = []

    for p1, p2 in zip(endpoints, startpoints):
        matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
        matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

    A = torch.tensor(matrix, dtype=torch.float)
    B = torch.tensor(startpoints, dtype=torch.float).view(8)
514
    res = torch.lstsq(B, A)[0]
515
516
517
    return res.squeeze_(1).tolist()


518
def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=None):
    """Perform perspective transform of the given PIL Image.

    Args:
        img (PIL Image): Image to be transformed.
        startpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the original image
        endpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image
        interpolation: Resampling filter passed to ``img.transform``. Default: ``Image.BICUBIC``.
        fill (n-tuple or int or float): Pixel fill value for the area outside the transformed
            image. If int or float, the value is used for all bands respectively.
            This option is only available for ``pillow>=5.0.0``.

    Returns:
        PIL Image:  Perspectively transformed Image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not F_pil._is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # Extra keyword options for ``img.transform`` derived from ``fill``; the
    # '5.0.0' argument is the minimum Pillow version for this option.
    opts = _parse_fill(fill, img, '5.0.0')

    coeffs = _get_perspective_coeffs(startpoints, endpoints)
    return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts)


def vflip(img: Tensor) -> Tensor:
    """Vertically flip the given PIL Image or torch Tensor.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor:  Vertically flipped image.
    """
    # Non-tensor input is handled by the PIL backend, tensors by the tensor backend.
    if not isinstance(img, torch.Tensor):
        return F_pil.vflip(img)

    return F_t.vflip(img)


def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Cut the four corner windows and the central window of the given size.

    Works on a PIL Image or on a Tensor. A Tensor is expected to have
    [..., H, W] shape, where ... stands for any number of leading dimensions.

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. An int gives a
            square crop (size, size); a tuple or list of length 1 is read as
            (size[0], size[0]); otherwise (h, w) is expected.

    Returns:
        tuple: (tl, tr, bl, br, center) - the top left, top right, bottom left,
        bottom right and center crops.

    Raises:
        ValueError: If ``size`` does not describe two dimensions, or if the crop
            is larger than the image.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    image_width, image_height = _get_image_size(img)
    crop_height, crop_width = size
    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    # Offsets of the bottom row and right column of corner crops.
    bottom = image_height - crop_height
    right = image_width - crop_width

    tl = crop(img, 0, 0, crop_height, crop_width)
    tr = crop(img, 0, right, crop_height, crop_width)
    bl = crop(img, bottom, 0, crop_height, crop_width)
    br = crop(img, bottom, right, crop_height, crop_width)
    center = center_crop(img, [crop_height, crop_width])

    return tl, tr, bl, br, center


def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]:
    """Produce the five crops of the image plus the five crops of its mirror image.

    The four corner crops and the central crop are taken from the image, then
    again from a flipped copy (horizontal flip unless ``vertical_flip``).

    Works on a PIL Image or on a Tensor. A Tensor is expected to have
    [..., H, W] shape, where ... stands for any number of leading dimensions.

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. An int gives a
            square crop (size, size); a tuple or list of length 1 is read as
            (size[0], size[0]); otherwise (h, w) is expected.
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Returns:
        tuple: (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
        - the corner and center crops of the image followed by the same crops
        of the flipped image.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    first_five = five_crop(img, size)

    mirrored = vflip(img) if vertical_flip else hflip(img)
    second_five = five_crop(mirrored, size)

    return first_five + second_five


def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Scale the brightness of an image.

    Tensors are handed to the tensor backend; any other input to the PIL backend.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image or Tensor: Brightness adjusted image.
    """
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_brightness(img, brightness_factor)

    return F_t.adjust_brightness(img, brightness_factor)


def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Scale the contrast of an image.

    Tensors are handed to the tensor backend; any other input to the PIL backend.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image or Tensor: Contrast adjusted image.
    """
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_contrast(img, contrast_factor)

    return F_t.adjust_contrast(img, contrast_factor)


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Scale the color saturation of an image.

    Tensors are handed to the tensor backend; any other input to the PIL backend.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image or Tensor: Saturation adjusted image.
    """
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_saturation(img, saturation_factor)

    return F_t.adjust_saturation(img, saturation_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Adjust the hue of a PIL Image.

    The hue is changed by converting the image to HSV, cyclically shifting
    the intensities of the hue channel (H), and converting the result back to
    the original image mode.

    ``hue_factor`` is the amount of shift applied to the H channel and must
    lie in the interval ``[-0.5, 0.5]``.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (PIL Image): PIL Image to be adjusted. Tensor inputs are rejected.
        hue_factor (float): Shift applied to the hue channel, in ``[-0.5, 0.5]``.
            ``0.5`` and ``-0.5`` give a complete reversal of the hue channel in
            HSV space in the positive and negative direction respectively, so
            both yield an image with complementary colors. ``0`` means no shift
            and gives the original image.

    Returns:
        PIL Image: Hue adjusted image.

    Raises:
        TypeError: If ``img`` is a ``torch.Tensor``.
    """
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_hue(img, hue_factor)

    # Only the PIL backend is wired up here, so tensors are rejected.
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
729
730


731
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Apply gamma correction to an image.

    Also known as the Power Law Transform. In RGB mode, intensities are
    adjusted according to

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            A gamma larger than 1 makes the shadows darker, while a gamma
            smaller than 1 makes dark regions lighter.
        gain (float): The constant multiplier.

    Returns:
        PIL Image or Tensor: Gamma correction adjusted image.
    """
    if not isinstance(img, torch.Tensor):
        # Every non-tensor input goes to the PIL backend.
        return F_pil.adjust_gamma(img, gamma, gain)
    return F_t.adjust_gamma(img, gamma, gain)
757
758


vfdev's avatar
vfdev committed
759
def _get_inverse_affine_matrix(
        center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
) -> List[float]:
    """Compute the inverse of an affine transformation matrix.

    Helper used by ``rotate`` and ``affine``. The forward transform is
    ``M = T * C * RSS * C^-1``. This function returns the first two rows of
    ``M^-1 = C * RSS^-1 * C^-1 * T^-1`` as a flat list ``[a, b, c, d, e, f]``,
    i.e. the matrix ``[[a, b, c], [d, e, f], [0, 0, 1]]``.

    Args:
        center (list of float): ``[cx, cy]`` center of the transformation.
        angle (float): rotation angle in degrees.
        translate (list of float): ``[tx, ty]`` post-rotation translation.
        scale (float): overall scale factor; must be non-zero (it is used as a divisor).
        shear (list of float): ``[sx, sy]`` shear angles in degrees.

    Returns:
        list of float: the 6 coefficients of the inverse affine matrix
        (row-major, constant last row omitted).
    """
    # As explained in PIL.Image.rotate, we need the INVERSE of the affine
    # transformation matrix M = T * C * RSS * C^-1, where
    #   T   is the translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #   C   is the translation matrix that keeps the center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #   RSS is the rotation with scale and shear matrix:
    #
    #   RSS(a, s, (sx, sy)) =
    #   = R(a) * S(s) * SHy(sy) * SHx(sx)
    #   = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #     [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #     [ 0                    , 0                                        , 1 ]
    #
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    #   SHx(s) = [1, -tan(s)]  and  SHy(s) = [1      , 0]
    #            [0, 1      ]               [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx, sy = [math.radians(s) for s in shear]

    cx, cy = center
    tx, ty = translate

    # RSS without scaling
    a = math.cos(rot - sy) / math.cos(sy)
    b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
    c = math.sin(rot - sy) / math.cos(sy)
    d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)

    # Inverted rotation matrix with scale and shear.
    # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1,
    # so the inverse is the adjugate [[d, -b], [-c, a]] divided by scale.
    matrix = [d, -b, 0.0, -c, a, 0.0]
    matrix = [x / scale for x in matrix]

    # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
    matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
    matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)

    # Apply center translation: C * RSS^-1 * C^-1 * T^-1
    matrix[2] += cx
    matrix[5] += cy

    return matrix
807

vfdev's avatar
vfdev committed
808

vfdev's avatar
vfdev committed
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
def rotate(
        img: Tensor, angle: float, resample: int = 0, expand: bool = False,
        center: Optional[List[int]] = None, fill: Optional[int] = None
) -> Tensor:
    """Rotate an image by ``angle`` degrees.

    ``img`` may be a PIL Image or a Tensor. A Tensor is expected to have
    [..., H, W] shape, where ... stands for any number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to be rotated.
        angle (float or int): rotation angle value in degrees, counter-clockwise.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (list or tuple, optional): Optional center of rotation in pixel coordinates.
            Origin is the upper left corner. Default is the center of the image.
        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
            image. If int or float, the value is used for all bands respectively.
            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.

    Returns:
        PIL Image or Tensor: Rotated image.

    Raises:
        TypeError: if ``angle`` is not an int or float, or if ``center`` is
            given but is not a list or tuple.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    if not isinstance(img, torch.Tensor):
        return F_pil.rotate(img, angle=angle, resample=resample, expand=expand, center=center, fill=fill)

    # The tensor path expresses the rotation center relative to the image
    # center, so "no explicit center" maps to the origin (0, 0).
    rel_center = [0.0, 0.0]
    if center is not None:
        width_height = _get_image_size(img)
        # Shift the pixel coordinates so that (0, 0) is the middle of the image.
        rel_center = [1.0 * (c - s * 0.5) for c, s in zip(center, width_height)]

    # affine and rotate currently disagree on the direction of a positive
    # angle, which is why the angle is negated here.
    inverse_matrix = _get_inverse_affine_matrix(rel_center, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
    return F_t.rotate(img, matrix=inverse_matrix, resample=resample, expand=expand, fill=fill)


vfdev's avatar
vfdev committed
860
861
862
863
864
865
866
def affine(
        img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
        resample: int = 0, fillcolor: Optional[int] = None
) -> Tensor:
    """Apply an affine transformation to the image, keeping the image center invariant.

    The image can be a PIL Image or a Tensor. A Tensor is expected to have
    [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to transform.
        angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
        translate (list or tuple of integers): horizontal and vertical translations
            (post-rotation translation). Must contain exactly two values.
        scale (float): overall scale; must be positive.
        shear (float or tuple or list): shear angle value in degrees between -180 to 180,
            clockwise direction. A single number shears parallel to the x axis only.
            If a tuple or list of two values is given, the first value corresponds to a
            shear parallel to the x axis and the second to a shear parallel to the y axis.
            A one-element tuple or list applies its value to both axes.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image is PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
            If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
        fillcolor (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0).
            This option is not supported for Tensor input. Fill value for the area outside the transform in the output
            image is always 0.

    Returns:
        PIL Image or Tensor: Transformed image.

    Raises:
        TypeError: if ``angle`` is not an int or float, ``translate`` is not a list or
            tuple, or ``shear`` is neither a number nor a list or tuple.
        ValueError: if ``translate`` does not have length 2, ``scale`` is not positive,
            or ``shear`` (as a sequence) does not contain one or two values.
    """
    # ---- validate arguments (order matters: callers may rely on which error fires first)
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")
    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")
    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")
    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")
    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    # ---- normalize arguments to the types the matrix helper expects
    if isinstance(angle, int):
        angle = float(angle)
    if isinstance(translate, tuple):
        translate = list(translate)

    # A scalar shear applies to x only; a one-element sequence applies to both axes.
    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]
    if isinstance(shear, tuple):
        shear = list(shear)
    if len(shear) == 1:
        shear = [shear[0], shear[0]]
    if len(shear) != 2:
        raise ValueError("Shear should be a sequence containing two values. Got {}".format(shear))

    image_size = _get_image_size(img)
    if not isinstance(img, torch.Tensor):
        # PIL path: transform about the image center. The (size * 0.5 + 0.5)
        # center is deliberately avoided; with it, a 90-degree rotation ends up
        # shifted relative to torch.rot90 / F_t.affine.
        pil_center = [image_size[0] * 0.5, image_size[1] * 0.5]
        inverse_matrix = _get_inverse_affine_matrix(pil_center, angle, translate, scale, shear)
        return F_pil.affine(img, matrix=inverse_matrix, resample=resample, fillcolor=fillcolor)

    # Tensor path: a [0, 0] center is passed. Presumably, as in `rotate`, tensor
    # coordinates are relative to the image center (confirm against F_t.affine).
    # The helper is annotated with float translations, hence the conversion.
    float_translate = [1.0 * t for t in translate]
    inverse_matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, float_translate, scale, shear)
    return F_t.affine(img, matrix=inverse_matrix, resample=resample, fillcolor=fillcolor)
933
934


935
936
937
938
939
940
941
def to_grayscale(img, num_output_channels=1):
    """Convert a PIL Image to its grayscale version.

    Args:
        img (PIL Image): Image to be converted to grayscale.
        num_output_channels (int): Number of channels of the output image,
            either 1 or 3. Defaults to 1.

    Returns:
        PIL Image: Grayscale version of the image.
            if num_output_channels = 1 : returned image is single channel (mode ``'L'``)

            if num_output_channels = 3 : returned image is 3 channel (mode ``'RGB'``) with r = g = b

    Raises:
        TypeError: If ``img`` is not a PIL Image.
        ValueError: If ``num_output_channels`` is neither 1 nor 3.
    """
    if not F_pil._is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if num_output_channels == 1:
        return img.convert('L')

    if num_output_channels == 3:
        # Replicate the single luminance plane into three identical channels.
        luminance = np.array(img.convert('L'), dtype=np.uint8)
        return Image.fromarray(np.dstack([luminance, luminance, luminance]), 'RGB')

    raise ValueError('num_output_channels should be either 1 or 3')
961
962


963
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """ Erase a rectangular region of a Tensor image with the given value.

    Args:
        img (Tensor Image): Tensor image of size (C, H, W) to be erased. Only the
            last two dimensions are indexed, so inputs with extra leading
            dimensions, e.g. a batch of shape (B, C, H, W), are also supported.
        i (int): Row of the upper left corner of the erased region (indexes the H dimension).
        j (int): Column of the upper left corner of the erased region (indexes the W dimension).
        h (int): Height of the erased region.
        w (int): Width of the erased region.
        v (Tensor): Erasing value. It is assigned to the region, so it must be
            broadcastable to the region's shape.
        inplace (bool, optional): If True, modify ``img`` in place; otherwise a
            clone is erased and ``img`` is left untouched. By default is set False.

    Returns:
        Tensor Image: Erased image (``img`` itself when ``inplace`` is True).

    Raises:
        TypeError: If ``img`` is not a ``torch.Tensor``.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))

    if not inplace:
        img = img.clone()

    # Index only the trailing H and W dimensions via Ellipsis, so (C, H, W) and
    # batched (B, C, H, W) inputs are handled alike. The former `img[:, ...]`
    # form assumed exactly three dims and erased the wrong axes on 4-D input.
    img[..., i:i + h, j:j + w] = v
    return img