_geometry.py 78 KB
Newer Older
1
import math
2
import numbers
3
import warnings
4
from typing import List, Optional, Sequence, Tuple, Union
5
6
7

import PIL.Image
import torch
8
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
9

10
from torchvision import datapoints
11
12
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric
13
from torchvision.transforms.functional import (
14
    _check_antialias,
15
    _compute_resized_output_size as __compute_resized_output_size,
16
    _get_perspective_coeffs,
17
    _interpolation_modes_from_int,
18
    InterpolationMode,
19
    pil_modes_mapping,
20
21
    pil_to_tensor,
    to_pil_image,
22
)
23

24
25
from torchvision.utils import _log_api_usage_once

26
from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil
27

28
29
from ._utils import is_simple_tensor

30

31
32
33
34
35
36
37
38
39
40
41
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    """Normalize ``interpolation`` to an :class:`InterpolationMode`.

    Integers are interpreted as Pillow resampling constants and translated with
    ``_interpolation_modes_from_int``; ``InterpolationMode`` members are returned unchanged.

    Raises:
        ValueError: if ``interpolation`` is neither an ``int`` nor an ``InterpolationMode``.
    """
    if isinstance(interpolation, int):
        return _interpolation_modes_from_int(interpolation)
    if isinstance(interpolation, InterpolationMode):
        return interpolation
    raise ValueError(
        f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
        f"but got {interpolation}."
    )


42
43
44
45
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Mirror ``image`` left-to-right by reversing its last dimension (width in ``(..., C, H, W)`` layout)."""
    return torch.flip(image, dims=[-1])


46
47
def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """Mirror a PIL image left-to-right by delegating to the PIL backend's ``hflip``."""
    flipped = _FP.hflip(image)
    return flipped
48
49


50
51
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    """Mirror a mask left-to-right.

    Masks are plain tensors here, so the image-tensor kernel is reused as-is.
    """
    flipped_mask = horizontal_flip_image_tensor(mask)
    return flipped_mask
52
53


54
def horizontal_flip_bounding_box(
    bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
    """Mirror bounding boxes left-to-right inside an image of size ``spatial_size`` = ``(height, width)``.

    Only the width is used. The input is cloned, so the caller's tensor is never modified, and the
    original shape (including any leading batch dims) is restored on return.
    """
    original_shape = bounding_box.shape
    image_width = spatial_size[1]
    # Flat (N, 4) private copy; every in-place op below writes into this copy only.
    boxes = bounding_box.clone().reshape(-1, 4)

    if format == datapoints.BoundingBoxFormat.XYXY:
        # x1' = W - x2 and x2' = W - x1: the fancy-indexed read is a copy, written back with columns swapped.
        boxes[:, [2, 0]] = boxes[:, [0, 2]].sub_(image_width).neg_()
    elif format == datapoints.BoundingBoxFormat.XYWH:
        # The left edge becomes W - (x + w); the width is unchanged. Done in place on a column view.
        boxes[:, 0].add_(boxes[:, 2]).sub_(image_width).neg_()
    else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
        # Only the center moves: cx' = W - cx.
        boxes[:, 0].sub_(image_width).neg_()

    return boxes.reshape(original_shape)
69
70


71
72
73
74
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    """Mirror every frame of ``video`` left-to-right; videos share the image-tensor kernel."""
    flipped_video = horizontal_flip_image_tensor(video)
    return flipped_video


def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    """Flip ``inpt`` left-to-right, dispatching on its type.

    Plain tensors (and every input while TorchScript-scripting) use the image-tensor kernel, TorchVision
    datapoints use their own ``horizontal_flip`` method, and PIL images use the PIL kernel.

    Raises:
        TypeError: for any other input type.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(horizontal_flip)

    # NOTE(review): presumably kept as one if/elif chain so that, when scripted, the constant
    # `torch.jit.is_scripting()` prunes the non-tensor branches — confirm before restructuring.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return horizontal_flip_image_tensor(inpt)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.horizontal_flip()
    elif isinstance(inpt, PIL.Image.Image):
        return horizontal_flip_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
90
91


92
93
94
95
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Mirror ``image`` top-to-bottom by reversing its second-to-last dimension (height in ``(..., C, H, W)``)."""
    return torch.flip(image, dims=[-2])


def vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """Mirror a PIL image top-to-bottom by delegating to the PIL backend's ``vflip``.

    Args:
        image: the PIL image to flip.

    Returns:
        The vertically flipped PIL image.
    """
    return _FP.vflip(image)
98
99


100
101
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    """Mirror a mask top-to-bottom.

    Masks are plain tensors here, so the image-tensor kernel is reused as-is.
    """
    flipped_mask = vertical_flip_image_tensor(mask)
    return flipped_mask
102
103
104


def vertical_flip_bounding_box(
    bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
    """Mirror bounding boxes top-to-bottom inside an image of size ``spatial_size`` = ``(height, width)``.

    Only the height is used. The input is cloned, so the caller's tensor is never modified, and the
    original shape (including any leading batch dims) is restored on return.
    """
    original_shape = bounding_box.shape
    image_height = spatial_size[0]
    # Flat (N, 4) private copy; every in-place op below writes into this copy only.
    boxes = bounding_box.clone().reshape(-1, 4)

    if format == datapoints.BoundingBoxFormat.XYXY:
        # y1' = H - y2 and y2' = H - y1: the fancy-indexed read is a copy, written back with columns swapped.
        boxes[:, [1, 3]] = boxes[:, [3, 1]].sub_(image_height).neg_()
    elif format == datapoints.BoundingBoxFormat.XYWH:
        # The top edge becomes H - (y + h); the height is unchanged. Done in place on a column view.
        boxes[:, 1].add_(boxes[:, 3]).sub_(image_height).neg_()
    else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
        # Only the center moves: cy' = H - cy.
        boxes[:, 1].sub_(image_height).neg_()

    return boxes.reshape(original_shape)
119
120


121
122
123
124
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
    """Mirror every frame of ``video`` top-to-bottom; videos share the image-tensor kernel."""
    flipped_video = vertical_flip_image_tensor(video)
    return flipped_video


def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    """Flip ``inpt`` top-to-bottom, dispatching on its type.

    Plain tensors (and every input while TorchScript-scripting) use the image-tensor kernel, TorchVision
    datapoints use their own ``vertical_flip`` method, and PIL images use the PIL kernel.

    Raises:
        TypeError: for any other input type.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(vertical_flip)

    # NOTE(review): presumably kept as one if/elif chain so that, when scripted, the constant
    # `torch.jit.is_scripting()` prunes the non-tensor branches — confirm before restructuring.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return vertical_flip_image_tensor(inpt)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.vertical_flip()
    elif isinstance(inpt, PIL.Image.Image):
        return vertical_flip_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
140
141


142
143
144
145
146
147
# The functional names were changed to align with the transform classes, e.g. `RandomHorizontalFlip`. Still,
# `hflip` and `vflip` are prevalent and well understood, so they remain available as plain aliases and are
# not deprecated.
hflip, vflip = horizontal_flip, vertical_flip


148
def _compute_resized_output_size(
    spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
    """Compute the target ``[height, width]`` by delegating to the ``torchvision.transforms.functional`` helper.

    Although annotated as ``List[int]``, ``size`` may also be a bare ``int``; it is then promoted to ``[size]``.

    Raises:
        ValueError: if ``max_size`` is given while ``size`` is a sequence whose length is not 1.
    """
    if isinstance(size, int):
        size = [size]
    elif max_size is not None and len(size) != 1:
        raise ValueError(
            "max_size should only be passed if size specifies the length of the smaller edge, "
            "i.e. size should be an int or a sequence of length 1 in torchscript mode."
        )
    return __compute_resized_output_size(spatial_size, size=size, max_size=max_size)
159
160


161
162
163
def resize_image_tensor(
    image: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Resize the last two (height, width) dims of ``image`` with ``torch.nn.functional.interpolate``.

    All leading dims are folded into one batch dim for the call and restored afterwards. Dtypes the selected
    mode/device cannot handle natively are round-tripped through float32, with rounding when casting back to
    integer dtypes. ``image`` itself is returned when the target size equals the current size.
    """
    interpolation = _check_interpolation(interpolation)

    antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation)
    # Asserts that no string survives `_check_antialias` (also narrows the type for TorchScript).
    assert not isinstance(antialias, str)
    antialias = False if antialias is None else antialias

    align_corners: Optional[bool] = None
    if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
        align_corners = False
    else:
        # The default of antialias should be True from 0.17, so we don't warn or
        # error if other interpolation modes are used. This is documented.
        antialias = False

    original_shape = image.shape
    num_elements = image.numel()
    num_channels, old_height, old_width = original_shape[-3:]
    new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return image
    elif num_elements > 0:
        image = image.reshape(-1, num_channels, old_height, old_width)
        input_dtype = image.dtype

        # Dtypes interpolate() is fed directly for this mode/device; anything else goes through float32.
        native_dtypes = [torch.float32, torch.float64]
        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
            # uint8 dtype can be included for cpu and cuda input if nearest mode
            native_dtypes.append(torch.uint8)
        elif image.device.type == "cpu":
            # uint8 dtype support for bilinear and bicubic is limited to cpu and
            # according to our benchmarks, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
            if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (
                interpolation == InterpolationMode.BICUBIC
            ):
                native_dtypes.append(torch.uint8)

        strides = image.stride()
        is_single_channels_last = image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1
        if is_single_channels_last and num_elements != strides[0]:
            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
            # contiguous even though the input is un-ambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
            # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
            # we should be able to remove this hack.
            restrided = list(strides)
            restrided[0] = num_elements
            image = image.as_strided((1, num_channels, old_height, old_width), restrided)

        needs_float_path = input_dtype not in native_dtypes
        if needs_float_path:
            image = image.to(dtype=torch.float32)

        image = interpolate(
            image,
            size=[new_height, new_width],
            mode=interpolation.value,
            align_corners=align_corners,
            antialias=antialias,
        )

        if needs_float_path:
            if interpolation == InterpolationMode.BICUBIC and input_dtype == torch.uint8:
                # Clamp to the uint8 range before casting back. uint8 bicubic is accepted natively on CPU
                # (see above), so this is only reached for uint8 inputs on other devices.
                image = image.clamp_(min=0, max=255)
            if input_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
                image = image.round_()
            image = image.to(dtype=input_dtype)

    return image.reshape(original_shape[:-3] + (num_channels, new_height, new_width))
238
239


240
@torch.jit.unused
def resize_image_pil(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
) -> PIL.Image.Image:
    """Resize a PIL image with Pillow's ``Image.resize``.

    The target size follows ``_compute_resized_output_size``; the input image is returned unchanged when the
    computed size equals its current size.
    """
    current_size = (image.height, image.width)
    new_height, new_width = _compute_resized_output_size(
        current_size,
        size=size,  # type: ignore[arg-type]
        max_size=max_size,
    )

    interpolation = _check_interpolation(interpolation)

    if (new_height, new_width) == current_size:
        return image

    # Pillow expects (width, height).
    resample = pil_modes_mapping[interpolation]
    return image.resize((new_width, new_height), resample=resample)
260
261


262
263
264
def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
    """Resize a mask with nearest-neighbour interpolation (values are never blended).

    Masks with fewer than 3 dims get a temporary leading dim so they fit the image-tensor kernel,
    which is removed again before returning.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = resize_image_tensor(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)

    return output.squeeze(0) if needs_squeeze else output
275
276


277
def resize_bounding_box(
    bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Rescale box coordinates for an image resized from ``spatial_size`` = ``(height, width)``.

    Returns ``(boxes, (new_height, new_width))``. Coordinate columns are multiplied by the per-axis ratios
    ``[w, h, w, h]`` and the result is cast back to the input dtype. If the size does not change, the inputs
    are returned as-is.
    """
    old_height, old_width = spatial_size
    new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return bounding_box, spatial_size

    scale_x = new_width / old_width
    scale_y = new_height / old_height
    scale_factors = torch.tensor([scale_x, scale_y, scale_x, scale_y], device=bounding_box.device)
    resized = bounding_box.mul(scale_factors).to(bounding_box.dtype)
    return resized, (new_height, new_width)
293
294


295
296
297
def resize_video(
    video: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Resize every frame of ``video``; videos are handled by the image-tensor kernel."""
    return resize_image_tensor(
        video,
        size=size,
        interpolation=interpolation,
        max_size=max_size,
        antialias=antialias,
    )


305
def resize(
    inpt: datapoints._InputTypeJIT,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
    """Resize ``inpt``, dispatching on its type.

    Plain tensors (and every input while TorchScript-scripting) use the image-tensor kernel, TorchVision
    datapoints use their own ``resize`` method, and PIL images use the PIL kernel. For PIL input an explicit
    ``antialias=False`` only triggers a warning, since the argument is ignored there.

    Raises:
        TypeError: for any other input type.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(resize)

    # NOTE(review): presumably kept as one if/elif chain so that, when scripted, the constant
    # `torch.jit.is_scripting()` prunes the non-tensor branches — confirm before restructuring.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
    elif isinstance(inpt, PIL.Image.Image):
        if antialias is False:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
327
328


329
def _affine_parse_args(
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
    """Validate and normalize the parameters shared by the affine kernels.

    Returns ``(angle, translate, shear, center)`` where ``angle`` is a float, ``translate`` is a list,
    ``shear`` is a two-element list (a scalar ``s`` becomes ``[s, 0.0]`` and a one-element sequence ``[s]``
    becomes ``[s, s]``), and ``center``, when given, is a list of floats.

    Raises:
        TypeError: for arguments of the wrong type.
        ValueError: for a ``translate`` or ``shear`` of the wrong length, or a non-positive ``scale``.
    """
    # Validation.
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    # Normalization.
    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None:
        if not isinstance(center, (list, tuple)):
            raise TypeError("Argument center should be a sequence")
        center = [float(c) for c in center]

    return angle, translate, shear, center


382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
def _get_inverse_affine_matrix(
    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
    # Helper method to compute inverse matrix for affine transformation

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                      , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # Cached results
    cos_sy = math.cos(sy)
    tan_sx = math.tan(sx)
    rot_minus_sy = rot - sy
    cx_plus_tx = cx + tx
    cy_plus_ty = cy + ty

    # Rotate Scale Shear (RSS) without scaling
    a = math.cos(rot_minus_sy) / cos_sy
    b = -(a * tan_sx + math.sin(rot))
    c = math.sin(rot_minus_sy) / cos_sy
    d = math.cos(rot) - c * tan_sx

    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
    else:
        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
        # Apply inverse of center translation: RSS * C^-1
        # and then apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy

    return matrix


def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired of PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
    # Points are shifted due to affine matrix torch convention about
    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
    half_w = 0.5 * w
    half_h = 0.5 * h
    pts = torch.tensor(
        [
            [-half_w, -half_h, 1.0],
            [-half_w, half_h, 1.0],
            [half_w, half_h, 1.0],
            [half_w, -half_h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, max_vals = new_pts.aminmax(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    halfs = torch.tensor((half_w, half_h))
    min_vals.add_(halfs)
    max_vals.add_(halfs)

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    inv_tol = 1.0 / tol
    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
    size = cmax.sub_(cmin)
    return int(size[0]), int(size[1])  # w, h


def _apply_grid_transform(
    img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints._FillTypeJIT
) -> torch.Tensor:
    """Resample ``img`` at the normalized coordinates in ``grid`` with ``grid_sample``.

    ``img`` is assumed to be a 4-D ``(N, C, H, W)`` batch (TODO confirm for all callers); a single grid is
    shared by the whole batch. Samples falling outside the source are zero, or ``fill`` (scalar or
    per-channel) when given. Images whose dtype differs from ``grid``'s are sampled in ``grid``'s dtype,
    then rounded and cast back.
    """
    # We are using context knowledge that grid should have float dtype.
    same_dtype = img.dtype == grid.dtype
    work = img if same_dtype else img.to(grid.dtype)

    work_shape = work.shape
    if work_shape[0] > 1:
        # Apply the same grid to every image in the batch.
        grid = grid.expand(work_shape[0], -1, -1, -1)

    if fill is not None:
        # Sample an all-ones channel alongside the image: afterwards it tells how much of each output pixel came
        # from inside the source, which is cheaper than running grid_sample() twice.
        ones_channel = torch.ones(
            (work_shape[0], 1, work_shape[2], work_shape[3]), dtype=work.dtype, device=work.device
        )
        work = torch.cat((work, ones_channel), dim=1)

    work = grid_sample(work, grid, mode=mode, padding_mode="zeros", align_corners=False)

    if fill is not None:
        work, coverage = torch.tensor_split(work, indices=(-1,), dim=-3)
        coverage = coverage.expand_as(work)
        fill_values = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
        fill_img = torch.tensor(fill_values, dtype=work.dtype, device=work.device).view(1, -1, 1, 1)
        if mode == "nearest":
            outside = coverage < 0.5
            work[outside] = fill_img.expand_as(work)[outside]
        else:  # 'bilinear'
            # Blend: img * m + (1 - m) * fill == m * (img - fill) + fill
            work = work.sub_(fill_img).mul_(coverage).add_(fill_img)

    return work if same_dtype else work.round_().to(img.dtype)
515
516
517
518
519
520


def _assert_grid_transform_inputs(
    image: torch.Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: datapoints._FillTypeJIT,
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    """Validate the arguments shared by the grid-sample based kernels, raising on the first problem found.

    Raises:
        TypeError: if ``matrix`` is given but is not a list.
        ValueError: for a ``matrix``/``coeffs`` of the wrong length, a ``fill`` that cannot broadcast to the
            image channels or has an unsupported type, or an unsupported ``interpolation`` mode.
    """
    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        if len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None:
        if isinstance(fill, (tuple, list)):
            # A single value broadcasts; otherwise there must be one value per channel.
            length = len(fill)
            num_channels = image.shape[-3]
            if length > 1 and length != num_channels:
                raise ValueError(
                    "The number of elements in 'fill' cannot broadcast to match the number of "
                    f"channels of the image ({length} != {num_channels})"
                )
        elif not isinstance(fill, (int, float)):
            raise ValueError("Argument fill should be either int, float, tuple or list")

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _affine_grid(
    theta: torch.Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> torch.Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
    dtype = theta.dtype
    device = theta.device

    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


577
def affine_image_tensor(
    image: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transform to an image tensor via grid sampling (nearest or bilinear only).

    Inputs with 3 dims, or more than 4, are temporarily reshaped to a 4-D batch and restored to their
    original shape afterwards. Empty tensors are returned unchanged.
    """
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    original_shape = image.shape
    ndim = image.ndim

    # Fold extra leading dims into the batch dim, or add a batch dim to a single image.
    needs_unsquash = ndim == 3 or ndim > 4
    if ndim > 4:
        image = image.reshape((-1,) + original_shape[-3:])
    elif ndim == 3:
        image = image.unsqueeze(0)

    height, width = original_shape[-2:]
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    pivot = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        pivot = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [float(t) for t in translate]
    inverse_matrix = _get_inverse_affine_matrix(pivot, angle, translate_f, scale, shear)

    _assert_grid_transform_inputs(image, inverse_matrix, interpolation.value, fill, ["nearest", "bilinear"])

    grid_dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(inverse_matrix, dtype=grid_dtype, device=image.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(original_shape)

    return output
626
627


628
@torch.jit.unused
def affine_image_pil(
    image: PIL.Image.Image,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    """Apply an affine transform to a PIL image through the PIL backend.

    The backend is given the inverse matrix (``_get_inverse_affine_matrix`` with its default ``inverted=True``).
    """
    interpolation = _check_interpolation(interpolation)
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    if center is None:
        # The default pivot is (w * 0.5, h * 0.5) without a +0.5 offset: with the offset, an image rotated by
        # 90 degrees ends up shifted relative to torch.rot90 or the tensor affine kernel.
        height, width = get_spatial_size_image_pil(image)
        center = [width * 0.5, height * 0.5]

    inverse_matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
    return _FP.affine(image, inverse_matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)
651
652


653
def _affine_bounding_box_with_expand(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    expand: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Apply an affine transform to bounding boxes and return ``(boxes, spatial_size)``.

    Each box is replaced by the axis-aligned hull of its 4 transformed corners, clamped to the (possibly new)
    canvas, and returned in the original format, shape and dtype. With ``expand=True`` the boxes are shifted so
    the transformed image frame starts at (0, 0), and ``spatial_size`` is replaced by the enlarged canvas size.
    Empty inputs are returned unchanged.
    """
    if bounding_box.numel() == 0:
        return bounding_box, spatial_size

    original_shape = bounding_box.shape
    original_dtype = bounding_box.dtype
    # Work on a private floating-point copy: the format conversion below runs in place.
    if bounding_box.is_floating_point():
        boxes = bounding_box.clone()
    else:
        boxes = bounding_box.float()
    dtype = boxes.dtype
    device = boxes.device
    boxes = convert_format_bounding_box(
        boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
    ).reshape(-1, 4)

    angle, translate, shear, center = _affine_parse_args(
        angle, translate, scale, shear, InterpolationMode.NEAREST, center
    )

    if center is None:
        height, width = spatial_size
        center = [width * 0.5, height * 0.5]

    forward_matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
    transposed_affine_matrix = torch.tensor(forward_matrix, dtype=dtype, device=device).reshape(2, 3).T

    # 1) Expand every XYXY box into its 4 corners in homogeneous coordinates, giving an (N * 4, 3) tensor with
    #    per-box order (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax), i.e. TL, TR, BR, BL.
    corners = boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    corners = torch.cat([corners, torch.ones(corners.shape[0], 1, device=device, dtype=dtype)], dim=-1)
    # 2) Map all corners through the forward affine matrix.
    mapped = torch.matmul(corners, transposed_affine_matrix)
    # 3) Regroup as (N boxes, 4 corners, x/y) and take each box's axis-aligned hull.
    mapped = mapped.reshape(-1, 4, 2)
    hull_min, hull_max = torch.aminmax(mapped, dim=1)
    out_bboxes = torch.cat([hull_min, hull_max], dim=1)

    if expand:
        height, width = spatial_size
        # Image frame corners TL, BL, BR, TR (the order does not matter for the minimum below).
        frame = torch.tensor(
            [
                [0.0, 0.0, 1.0],
                [0.0, float(height), 1.0],
                [float(width), float(height), 1.0],
                [float(width), 0.0, 1.0],
            ],
            dtype=dtype,
            device=device,
        )
        frame_mapped = torch.matmul(frame, transposed_affine_matrix)
        frame_origin = torch.amin(frame_mapped, dim=0, keepdim=True)
        # Shift the boxes so the expanded canvas starts at (0, 0).
        out_bboxes.sub_(frame_origin.repeat((1, 2)))
        # Estimate the canvas size from the inverse matrix with center=[0, 0].
        inverse_matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
        new_width, new_height = _compute_affine_output_size(inverse_matrix, width, height)
        spatial_size = (new_height, new_width)

    out_bboxes = clamp_bounding_box(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
    out_bboxes = convert_format_bounding_box(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)

    return out_bboxes.to(original_dtype), spatial_size
740
741
742
743


def affine_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to bounding boxes, keeping the original canvas size.

    This is ``_affine_bounding_box_with_expand`` with ``expand=False``. The spatial size that helper
    returns is discarded, because it only differs from ``spatial_size`` when expanding.

    Returns:
        The transformed boxes, in the same format and shape as ``bounding_box``.
    """
    transformed_boxes, _unused_spatial_size = _affine_bounding_box_with_expand(
        bounding_box,
        format=format,
        spatial_size=spatial_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
        expand=False,
    )
    return transformed_boxes
764
765


766
767
def affine_mask(
    mask: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to a segmentation mask.

    Masks are warped with nearest-neighbour interpolation so label values are never blended.
    A 2D mask is temporarily given a leading channel dimension and squeezed back afterwards.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = affine_image_tensor(
        mask,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    return output.squeeze(0) if needs_squeeze else output

797

798
799
800
801
802
803
def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to a video tensor.

    Videos are processed exactly like image tensors: every argument is forwarded unchanged to
    ``affine_image_tensor``.
    """
    return affine_image_tensor(
        video, angle=angle, translate=translate, scale=scale, shear=shear,
        interpolation=interpolation, fill=fill, center=center,
    )


820
def affine(
    inpt: datapoints._InputTypeJIT,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
    """Apply an affine transformation to any supported input type.

    The call is dispatched on the type of ``inpt``:

    * a plain tensor (or any input while TorchScript-scripting) goes to ``affine_image_tensor``,
    * a TorchVision datapoint goes to its own ``affine`` method,
    * a PIL image goes to ``affine_image_pil``.

    Raises:
        TypeError: if ``inpt`` is none of the above.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(affine)

    # TODO: consider deprecating integers from angle and shear in the future
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return affine_image_tensor(
            inpt, angle, translate=translate, scale=scale, shear=shear,
            interpolation=interpolation, fill=fill, center=center,
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.affine(
            angle, translate=translate, scale=scale, shear=shear,
            interpolation=interpolation, fill=fill, center=center,
        )
    elif isinstance(inpt, PIL.Image.Image):
        return affine_image_pil(
            inpt, angle, translate=translate, scale=scale, shear=shear,
            interpolation=interpolation, fill=fill, center=center,
        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
865
866


867
def rotate_image_tensor(
    image: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate an image tensor whose last three dimensions are ``(C, H, W)`` by ``angle`` degrees.

    Args:
        image: tensor of shape ``(..., C, H, W)``; leading dimensions are flattened into a batch.
        angle: rotation angle in degrees.
        interpolation: resampling mode, validated by ``_check_interpolation``.
        expand: if True, the output size is enlarged (via ``_compute_affine_output_size``) to hold the
            whole rotated image; otherwise it keeps ``(H, W)``.
        center: optional rotation center in pixel coordinates. A warning is emitted when it is combined
            with ``expand=True``.
        fill: value(s) used for pixels that fall outside the source image.

    Returns:
        Tensor of shape ``(..., C, H_out, W_out)``. Empty inputs are returned without resampling, but
        still reshaped to the output size.
    """
    interpolation = _check_interpolation(interpolation)

    input_shape = image.shape
    num_channels, height, width = input_shape[-3:]

    center_f = [0.0, 0.0]
    if center is not None:
        if expand:
            # TODO: Do we actually want to warn, or just document this?
            warnings.warn("The provided center argument has no effect on the result if expand is True")
        # Shift the pixel-coordinate center so that (0, 0) refers to the image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    # The rotate and affine implementations use opposite angle directions, hence -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

    if image.numel() == 0:
        # Nothing to resample; only work out which size the result would have.
        output = image
        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
    else:
        batch = image.reshape(-1, num_channels, height, width)

        _assert_grid_transform_inputs(batch, matrix, interpolation.value, fill, ["nearest", "bilinear"])

        ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
        # The sampling grid is built in floating point even for integer images.
        grid_dtype = batch.dtype if torch.is_floating_point(batch) else torch.float32
        theta = torch.tensor(matrix, dtype=grid_dtype, device=batch.device).reshape(1, 2, 3)
        grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
        output = _apply_grid_transform(batch, grid, interpolation.value, fill=fill)

        new_height, new_width = output.shape[-2:]

    return output.reshape(input_shape[:-3] + (num_channels, new_height, new_width))
909
910


911
@torch.jit.unused
def rotate_image_pil(
    image: PIL.Image.Image,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> PIL.Image.Image:
    """Rotate a PIL image by ``angle`` degrees using the PIL backend (``_FP.rotate``).

    If both ``center`` and ``expand=True`` are given, a warning is emitted and ``center`` is dropped
    before calling the backend.
    """
    interpolation = _check_interpolation(interpolation)

    if expand and center is not None:
        warnings.warn("The provided center argument has no effect on the result if expand is True")
        center = None

    pil_interpolation = pil_modes_mapping[interpolation]
    return _FP.rotate(image, angle, interpolation=pil_interpolation, expand=expand, fill=fill, center=center)


931
932
def rotate_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Rotate bounding boxes by ``angle`` degrees.

    Rotation is expressed as an affine transform with no translation, unit scale and no shear.

    Returns:
        A ``(boxes, spatial_size)`` pair, as produced by ``_affine_bounding_box_with_expand``.
    """
    if expand and center is not None:
        warnings.warn("The provided center argument has no effect on the result if expand is True")
        center = None

    # rotate and affine disagree on the angle direction, so the angle is negated for the affine helper.
    no_shift = [0.0, 0.0]
    no_shear = [0.0, 0.0]
    return _affine_bounding_box_with_expand(
        bounding_box,
        format=format,
        spatial_size=spatial_size,
        angle=-angle,
        translate=no_shift,
        scale=1.0,
        shear=no_shear,
        center=center,
        expand=expand,
    )
954
955


956
957
def rotate_mask(
    mask: torch.Tensor,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate a segmentation mask by ``angle`` degrees with nearest-neighbour interpolation.

    A 2D mask is temporarily given a leading channel dimension and squeezed back afterwards.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = rotate_image_tensor(
        mask,
        angle=angle,
        expand=expand,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    return output.squeeze(0) if needs_squeeze else output

983

984
985
986
def rotate_video(
    video: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate a video tensor; it is handled exactly like an image tensor by ``rotate_image_tensor``."""
    return rotate_image_tensor(
        video,
        angle,
        interpolation=interpolation,
        expand=expand,
        fill=fill,
        center=center,
    )


995
def rotate(
    inpt: datapoints._InputTypeJIT,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
    """Rotate any supported input by ``angle`` degrees.

    The call is dispatched on the type of ``inpt``:

    * a plain tensor (or any input while scripting) goes to ``rotate_image_tensor``,
    * a datapoint goes to its ``rotate`` method,
    * a PIL image goes to ``rotate_image_pil``.

    Raises:
        TypeError: for any other input type.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(rotate)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return rotate_image_tensor(
            inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.rotate(
            angle, interpolation=interpolation, expand=expand, fill=fill, center=center
        )
    elif isinstance(inpt, PIL.Image.Image):
        return rotate_image_pil(
            inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1017
1018


1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, (tuple, list)):
        if len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        elif len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        else:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )
    else:
        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")

    return [pad_left, pad_right, pad_top, pad_bottom]
1041

1042

1043
def pad_image_tensor(
    image: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad an image tensor of shape ``(..., C, H, W)``.

    Args:
        image: the tensor to pad.
        padding: PIL-style padding (see ``_parse_pad_padding``).
        fill: constant fill value. ``None`` means 0; a 1-element list is treated as a scalar; a longer
            list is applied per channel (only supported with ``padding_mode="constant"``).
        padding_mode: one of ``"constant"``, ``"edge"``, ``"reflect"``, ``"symmetric"``.

    Raises:
        ValueError: for an unknown ``padding_mode``.
    """
    # The public `padding` argument uses PIL's order `[left, top, right, bottom]` (our API mirrors PIL),
    # whereas `torch_pad`, which does the work internally, expects `[left, right, top, bottom]`.
    torch_padding = _parse_pad_padding(padding)

    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
        raise ValueError(
            f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
            f"but got `'{padding_mode}'`."
        )

    if fill is None:
        fill = 0

    # Scalars and single-element lists take the fast scalar path; longer lists are per-channel fills.
    if isinstance(fill, (int, float)):
        return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
    elif len(fill) == 1:
        return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
    else:
        return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
1069
1070
1071


def _pad_with_scalar_fill(
1072
    image: torch.Tensor,
1073
1074
1075
    torch_padding: List[int],
    fill: Union[int, float],
    padding_mode: str,
1076
) -> torch.Tensor:
1077
1078
    shape = image.shape
    num_channels, height, width = shape[-3:]
1079

1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
    batch_size = 1
    for s in shape[:-3]:
        batch_size *= s

    image = image.reshape(batch_size, num_channels, height, width)

    if padding_mode == "edge":
        # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
        # name.
        padding_mode = "replicate"

    if padding_mode == "constant":
        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
    elif padding_mode in ("reflect", "replicate"):
        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
        # TODO: See https://github.com/pytorch/pytorch/issues/40763
        dtype = image.dtype
        if not image.is_floating_point():
            needs_cast = True
            image = image.to(torch.float32)
        else:
            needs_cast = False
1103

1104
1105
1106
1107
1108
        image = torch_pad(image, torch_padding, mode=padding_mode)

        if needs_cast:
            image = image.to(dtype)
    else:  # padding_mode == "symmetric"
1109
        image = _pad_symmetric(image, torch_padding)
1110
1111

    new_height, new_width = image.shape[-2:]
1112

1113
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
1114
1115


1116
# TODO: This should be removed once torch_pad supports non-scalar padding values
def _pad_with_vector_fill(
    image: torch.Tensor,
    torch_padding: List[int],
    fill: List[float],
    padding_mode: str,
) -> torch.Tensor:
    """Constant-pad ``image`` with one fill value per channel.

    The image is first zero-padded, then each padded border strip is overwritten with ``fill``.

    Args:
        image: tensor of shape ``(..., C, H, W)``.
        torch_padding: ``[left, right, top, bottom]`` padding amounts.
        fill: per-channel fill values, broadcast against the channel dimension.
        padding_mode: must be ``"constant"``.

    Raises:
        ValueError: for any other ``padding_mode``.
    """
    if padding_mode != "constant":
        raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

    output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    left, right, top, bottom = torch_padding
    # Shape (C, 1, 1) so the values broadcast over every row/column of each border strip.
    fill_values = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

    if top > 0:
        output[..., :top, :] = fill_values
    if left > 0:
        output[..., :, :left] = fill_values
    if bottom > 0:
        output[..., -bottom:, :] = fill_values
    if right > 0:
        output[..., :, -right:] = fill_values
    return output


1141
1142
1143
# PIL-image padding kernel: a direct alias of the PIL backend's `_FP.pad`, used by `pad` for PIL inputs.
pad_image_pil = _FP.pad


1144
1145
def pad_mask(
    mask: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad a segmentation mask.

    Only scalar fill values are accepted (``None`` means 0). A 2D mask is temporarily given a leading
    channel dimension and squeezed back afterwards.

    Raises:
        ValueError: if ``fill`` is a tuple or list.
    """
    if fill is None:
        fill = 0

    if isinstance(fill, (tuple, list)):
        raise ValueError("Non-scalar fill value is not supported")

    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = pad_image_tensor(mask, padding=padding, fill=fill, padding_mode=padding_mode)

    return output.squeeze(0) if needs_squeeze else output
1168
1169


1170
def pad_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    padding: List[int],
    padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Shift bounding boxes to account for padding, and grow the spatial size accordingly.

    Returns:
        ``(boxes, new_spatial_size)``; the boxes are clamped to the padded canvas.

    Raises:
        ValueError: for any ``padding_mode`` other than ``"constant"``.
    """
    if padding_mode not in ["constant"]:
        # TODO: add support of other padding modes
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

    left, right, top, bottom = _parse_pad_padding(padding)

    # XYXY boxes move both corners. For every other format only the first two coordinates move;
    # presumably the last two are sizes (width/height), which padding does not change.
    if format == datapoints.BoundingBoxFormat.XYXY:
        offset = [left, top, left, top]
    else:
        offset = [left, top, 0, 0]
    shifted = bounding_box + torch.tensor(offset, dtype=bounding_box.dtype, device=bounding_box.device)

    old_height, old_width = spatial_size
    new_spatial_size = (old_height + top + bottom, old_width + left + right)

    return clamp_bounding_box(shifted, format=format, spatial_size=new_spatial_size), new_spatial_size
1195
1196


1197
1198
def pad_video(
    video: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad a video tensor; it is handled exactly like an image tensor by ``pad_image_tensor``."""
    return pad_image_tensor(
        video,
        padding,
        fill=fill,
        padding_mode=padding_mode,
    )


1206
def pad(
    inpt: datapoints._InputTypeJIT,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> datapoints._InputTypeJIT:
    """Pad any supported input.

    The call is dispatched on the type of ``inpt``:

    * a plain tensor (or any input while scripting) goes to ``pad_image_tensor``,
    * a datapoint goes to its ``pad`` method,
    * a PIL image goes to ``pad_image_pil``.

    Raises:
        TypeError: for any other input type.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(pad)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return pad_image_tensor(
            inpt, padding, fill=fill, padding_mode=padding_mode
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.pad(
            padding, fill=fill, padding_mode=padding_mode
        )
    elif isinstance(inpt, PIL.Image.Image):
        return pad_image_pil(
            inpt, padding, fill=fill, padding_mode=padding_mode
        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1227
1228


1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop the ``height x width`` window with top-left corner ``(top, left)`` from ``image``.

    The window may extend past any border of the image. Parts outside the image are zero-padded,
    so the result always has spatial size ``(height, width)``.

    Args:
        image: tensor of shape ``(..., C, H, W)``.
        top, left: coordinates of the window's top-left corner; they may be negative or beyond the image.
        height, width: size of the window.
    """
    h, w = image.shape[-2:]

    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        # Take the part of the window that overlaps the image. Both slice ends are clamped at 0.
        # Otherwise a window lying entirely left of / above the image (negative `right` / `bottom`)
        # would be read as "counted from the end" and return a non-empty slice of the wrong size.
        image = image[..., max(top, 0) : max(bottom, 0), max(left, 0) : max(right, 0)]
        # Pad the missing parts back, in `torch_pad` order [left, right, top, bottom].
        torch_padding = [
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    return image[..., top:bottom, left:right]


1247
1248
1249
# PIL-image cropping kernel: a direct alias of the PIL backend's `_FP.crop`, used by `crop` for PIL inputs.
crop_image_pil = _FP.crop


1250
1251
def crop_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Move bounding boxes into the coordinate frame of a crop window.

    Negative ``left``/``top`` values (a window starting outside the image, i.e. implicit padding) shift
    the boxes in the opposite direction.

    Returns:
        ``(boxes, (height, width))``; the boxes are clamped to the crop size.
    """
    # XYXY boxes move both corners; other formats only move their first two coordinates.
    if format == datapoints.BoundingBoxFormat.XYXY:
        offset = [left, top, left, top]
    else:
        offset = [left, top, 0, 0]
    shifted = bounding_box - torch.tensor(offset, dtype=bounding_box.dtype, device=bounding_box.device)

    crop_size = (height, width)
    return clamp_bounding_box(shifted, format=format, spatial_size=crop_size), crop_size
1269
1270


1271
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a segmentation mask with ``crop_image_tensor``.

    A 2D mask is temporarily given a leading channel dimension and squeezed back afterwards.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = crop_image_tensor(mask, top, left, height, width)

    return output.squeeze(0) if needs_squeeze else output
1284
1285


1286
1287
1288
1289
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a video tensor; it is handled exactly like an image tensor by ``crop_image_tensor``."""
    return crop_image_tensor(video, top=top, left=left, height=height, width=width)


Philip Meier's avatar
Philip Meier committed
1290
def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
    """Crop any supported input to the ``height x width`` window at ``(top, left)``.

    The call is dispatched on the type of ``inpt``:

    * a plain tensor (or any input while scripting) goes to ``crop_image_tensor``,
    * a datapoint goes to its ``crop`` method,
    * a PIL image goes to ``crop_image_pil``.

    Raises:
        TypeError: for any other input type.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return crop_image_tensor(
            inpt, top, left, height, width
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.crop(
            top, left, height, width
        )
    elif isinstance(inpt, PIL.Image.Image):
        return crop_image_pil(
            inpt, top, left, height, width
        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1305
1306


1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394

    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
1322
    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
1323
    base_grid[..., 0].copy_(x_grid)
1324
    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
1325
1326
1327
1328
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
1329
1330
1331
    shape = (1, oh * ow, 3)
    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))
1332
1333
1334
1335
1336

    output_grid = output_grid1.div_(output_grid2).sub_(1.0)
    return output_grid.view(1, oh, ow, 2)


1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
def _perspective_coefficients(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]],
) -> List[float]:
    if coefficients is not None:
        if startpoints is not None and endpoints is not None:
            raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.")
        elif len(coefficients) != 8:
            raise ValueError("Argument coefficients should have 8 float values")
        return coefficients
    elif startpoints is not None and endpoints is not None:
        return _get_perspective_coeffs(startpoints, endpoints)
    else:
        raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")


1354
def perspective_image_tensor(
    image: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transformation to an image tensor.

    The transform is given either by ``startpoints``/``endpoints`` or by 8 explicit ``coefficients``
    (see ``_perspective_coefficients``). The output has the same shape as the input. Empty tensors are
    returned as-is, after the arguments have been validated.
    """
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    input_shape = image.shape
    ndim = image.ndim

    # Bring the input to 4D: flatten extra leading dims, or add a batch dim to a single (C, H, W) image.
    needs_unsquash = ndim == 3 or ndim > 4
    if ndim > 4:
        image = image.reshape((-1,) + input_shape[-3:])
    elif ndim == 3:
        image = image.unsqueeze(0)

    _assert_grid_transform_inputs(
        image,
        matrix=None,
        interpolation=interpolation.value,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    oh, ow = input_shape[-2:]
    # The grid is built in floating point even for integer images.
    grid_dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=grid_dtype, device=image.device)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    return output.reshape(input_shape) if needs_unsquash else output
1398
1399


1400
@torch.jit.unused
def perspective_image_pil(
    image: PIL.Image.Image,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
    """Apply a perspective transformation to a PIL image using the PIL backend (``_FP.perspective``)."""
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)
    pil_interpolation = pil_modes_mapping[interpolation]
    return _FP.perspective(image, perspective_coeffs, interpolation=pil_interpolation, fill=fill)
1412
1413


1414
1415
def perspective_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transformation to bounding boxes.

    The perspective coefficients describe the end point -> start point mapping used for images, so each
    box's four corners are pushed through the *inverse* mapping. The axis-aligned hull of the mapped
    corners, clamped to ``spatial_size``, becomes the new box.

    Args:
        bounding_box: boxes in ``format``; any shape whose elements regroup as ``(-1, 4)``.
        format: box format of the input (and of the output).
        spatial_size: ``(height, width)`` used to clamp the result.
        startpoints, endpoints, coefficients: see ``_perspective_coefficients``.

    Returns:
        Boxes with the same shape and format as ``bounding_box``. Empty inputs are returned unchanged.

    Raises:
        RuntimeError: if the coefficients cannot be inverted (zero denominator).
    """
    if bounding_box.numel() == 0:
        return bounding_box

    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)

    original_shape = bounding_box.shape
    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
    bounding_box = (
        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    # All arithmetic happens in this dtype: the box dtype when it is floating point, float32 otherwise.
    # Every matmul operand below is created in / cast to it, since `torch.matmul` does not promote dtypes.
    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
    device = bounding_box.device

    # perspective_coeffs are computed as endpoint -> start point
    # We have to invert perspective_coeffs for bboxes:
    # (x, y) - end point and (x_out, y_out) - start point
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # and we would like to get:
    # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # and compute inv_coeffs in terms of coeffs: the adjugate of the 3x3 homography, normalized so that its
    # bottom-right entry (c0 * c4 - c1 * c3) becomes 1.

    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
    if denom == 0:
        raise RuntimeError(
            f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
            f"Denominator is zero, denom={denom}"
        )

    inv_coeffs = [
        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
        (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
        (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
        (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
    ]

    theta1 = torch.tensor(
        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
        dtype=dtype,
        device=device,
    )

    theta2 = torch.tensor(
        [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
    )

    # 1) Turn every box into its 4 corners in homogeneous coordinates, a tensor of shape (N * 4, 3):
    #    [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    #    i.e. top-left, top-right, bottom-right, bottom-left.
    #    Both the corners and the ones column are built in `dtype`, so the matmuls below see matching
    #    operand dtypes even for reduced-precision (e.g. float16) boxes.
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2).to(dtype)
    points = torch.cat([points, torch.ones(points.shape[0], 1, dtype=dtype, device=device)], dim=-1)

    # 2) Transform the points with the inverse perspective matrices:
    #   x = (theta1[0] . p) / (theta2[0] . p)
    #   y = (theta1[1] . p) / (theta2[1] . p)
    numer_points = torch.matmul(points, theta1.T)
    denom_points = torch.matmul(points, theta2.T)
    transformed_points = numer_points.div_(denom_points)

    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute each bounding box as the hull of its 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)

    out_bboxes = clamp_bounding_box(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=spatial_size,
    )

    # out_bboxes has shape [N boxes, 4]; convert back to the caller's format and shape.
    return convert_format_bounding_box(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
1506
1507


1508
1509
def perspective_mask(
    mask: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective warp to a mask tensor.

    The mask is resampled with ``InterpolationMode.NEAREST`` via ``perspective_image_tensor``.
    Masks with fewer than three dimensions get a temporary leading dimension, which is removed
    from the result again.
    """
    add_leading_dim = mask.ndim < 3
    if add_leading_dim:
        mask = mask.unsqueeze(0)

    warped = perspective_image_tensor(
        mask,
        startpoints,
        endpoints,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        coefficients=coefficients,
    )

    return warped.squeeze(0) if add_leading_dim else warped

1530

1531
1532
def perspective_video(
    video: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective warp to a video tensor.

    Videos need no dedicated handling: every argument is forwarded unchanged to
    ``perspective_image_tensor``.
    """
    warped_video = perspective_image_tensor(
        video,
        startpoints,
        endpoints,
        interpolation=interpolation,
        fill=fill,
        coefficients=coefficients,
    )
    return warped_video
1542
1543


1544
def perspective(
    inpt: datapoints._InputTypeJIT,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
    """Apply a perspective transform to a plain tensor, a TorchVision datapoint, or a PIL image.

    Dispatch:
      * plain tensors (and every input while scripting) -> ``perspective_image_tensor``
      * datapoints -> their own ``.perspective`` method
      * PIL images -> ``perspective_image_pil``

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(perspective)

    # The scripting check must stay first in this chain: while scripting, only this branch is taken.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return perspective_image_tensor(
            inpt,
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.perspective(
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
    elif isinstance(inpt, PIL.Image.Image):
        return perspective_image_pil(
            inpt,
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1571
1572


1573
def elastic_image_tensor(
    image: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Warp an image tensor by a dense displacement field.

    ``displacement`` is added to an identity sampling grid (``_create_identity_grid``), and the image
    is resampled with ``_apply_grid_transform``.

    Shapes and dtypes:
      * ``displacement`` must have shape ``(1, H, W, 2)``, where ``H, W`` are the last two dims of ``image``.
      * 3D inputs and inputs with more than 4 dims are reshaped to a 4D batch for the warp and
        restored to their original shape afterwards.
      * Non-floating images are processed in float32.
      * CPU float16 images are upcast to float32 and the result is cast back to float16.
      * Empty images are returned unchanged.

    Raises:
        ValueError: if ``displacement`` does not have the expected shape.
    """
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    orig_shape = image.shape
    orig_ndim = image.ndim
    device = image.device
    compute_dtype = image.dtype if torch.is_floating_point(image) else torch.float32

    # Patch: support (cpu, float16) input by computing in float32 and casting back at the end.
    upcast_cpu_half = device.type == "cpu" and compute_dtype == torch.float16
    if upcast_cpu_half:
        image = image.to(torch.float32)
        compute_dtype = torch.float32

    # Known limitation: for a uint8 image with a float64 displacement, the displacement is cast
    # to float32 and all computations are done in float32.
    expected_shape = (1,) + orig_shape[-2:] + (2,)
    if expected_shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    # Bring the input to a 4D batch for the grid transform.
    if orig_ndim > 4:
        image = image.reshape((-1,) + orig_shape[-3:])
    elif orig_ndim == 3:
        image = image.unsqueeze(0)
    restore_shape = orig_ndim > 4 or orig_ndim == 3

    if displacement.dtype != compute_dtype or displacement.device != device:
        displacement = displacement.to(dtype=compute_dtype, device=device)

    height, width = orig_shape[-2:]
    sampling_grid = _create_identity_grid((height, width), device=device, dtype=compute_dtype).add_(displacement)
    output = _apply_grid_transform(image, sampling_grid, interpolation.value, fill=fill)

    if restore_shape:
        output = output.reshape(orig_shape)
    if upcast_cpu_half:
        output = output.to(torch.float16)
    return output
1627
1628


1629
@torch.jit.unused
def elastic_image_pil(
    image: PIL.Image.Image,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> PIL.Image.Image:
    """Elastic-warp a PIL image by round-tripping through the tensor kernel.

    The image is converted with ``pil_to_tensor`` and warped by ``elastic_image_tensor``. The result
    is converted back with ``to_pil_image`` using the input image's mode.
    """
    return to_pil_image(
        elastic_image_tensor(pil_to_tensor(image), displacement, interpolation=interpolation, fill=fill),
        mode=image.mode,
    )
1639
1640


1641
def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Build a ``(1, height, width, 2)`` sampling grid that maps every output pixel onto itself.

    ``size`` is ``(height, width)``.
    * Channel 0 holds the x coordinate; it varies along the width axis.
    * Channel 1 holds the y coordinate; it varies along the height axis.

    Coordinates are evenly spaced from ``(1 - n) / n`` to ``(n - 1) / n``, i.e. pixel centers in
    normalized ``[-1, 1]`` space.
    """
    height, width = size
    grid = torch.empty(1, height, width, 2, device=device, dtype=dtype)

    xs = torch.linspace((1 - width) / width, (width - 1) / width, width, device=device, dtype=dtype)
    ys = torch.linspace((1 - height) / height, (height - 1) / height, height, device=device, dtype=dtype)

    # xs broadcasts across rows; ys is turned into a column so it broadcasts across columns.
    grid[..., 0].copy_(xs)
    grid[..., 1].copy_(ys.unsqueeze(-1))
    return grid


1653
1654
def elastic_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    displacement: torch.Tensor,
) -> torch.Tensor:
    """Transform bounding boxes with the displacement field of the elastic image transform.

    The inverse of the sampling grid is approximated as ``identity_grid - displacement``; this is
    not an exact inverse. Each box corner is processed as follows:
      1. Round it up to an integer pixel index and clamp it to the grid.
      2. Look it up in the approximate inverse grid.
      3. Map the result from normalized coordinates back to pixel coordinates.
    The enclosing axis-aligned box of the four transformed corners is clamped to ``spatial_size``.

    Args:
        bounding_box: boxes in ``format``. They are reshaped to ``(-1, 4)`` internally, and the
            original shape is restored on output.
        format: coordinate format of ``bounding_box``; the output uses the same format.
        spatial_size: ``(height, width)`` of the canvas.
        displacement: displacement field, indexed together with a ``(1, height, width, 2)``
            identity grid (presumably the same ``(1, H, W, 2)`` field as the image kernel; not
            validated here — TODO confirm).

    Returns:
        Transformed boxes with the same shape and format as ``bounding_box``. Empty inputs are
        returned unchanged.
    """
    if bounding_box.numel() == 0:
        return bounding_box

    device = bounding_box.device
    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    original_shape = bounding_box.shape
    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
    bounding_box = (
        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
    # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
    # This is not an exact inverse of the grid
    inv_grid = id_grid.sub_(displacement)

    # Corners of every box: (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    if points.is_floating_point():
        points = points.ceil_()
    index_xy = points.to(dtype=torch.long)
    # Clamp the corner indices to valid grid cells:
    # - a box touching the right/bottom edge (e.g. xmax == width, or a value that ceils to width)
    #   would otherwise index past the end of the (height, width) grid and raise IndexError;
    # - negative coordinates would otherwise silently wrap around.
    image_height, image_width = spatial_size
    index_x = index_xy[:, 0].clamp(0, image_width - 1)
    index_y = index_xy[:, 1].clamp(0, image_height - 1)

    # Transform points: normalized grid coordinates -> pixel coordinates
    t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

    # Enclosing box of the 4 transformed corners of each box
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = clamp_bounding_box(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=spatial_size,
    )

    return convert_format_bounding_box(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
1702
1703


1704
1705
1706
def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Elastic-warp a mask tensor using nearest-neighbour sampling.

    Masks with fewer than three dimensions get a temporary leading dimension before calling
    ``elastic_image_tensor``. That dimension is squeezed out of the result.
    """
    added_dim = mask.ndim < 3
    batched = mask.unsqueeze(0) if added_dim else mask

    output = elastic_image_tensor(
        batched, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill
    )

    if added_dim:
        output = output.squeeze(0)
    return output
1721
1722


1723
1724
1725
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Elastic-warp a video tensor.

    All arguments are forwarded to ``elastic_image_tensor``.
    """
    warped_video = elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
    return warped_video
1730
1731


1732
def elastic(
    inpt: datapoints._InputTypeJIT,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
    """Elastic-warp a plain tensor, a TorchVision datapoint, or a PIL image.

    Dispatch:
      * plain tensors (and every input while scripting) -> ``elastic_image_tensor``
      * datapoints -> their own ``.elastic`` method
      * PIL images -> ``elastic_image_pil``

    Raises:
        TypeError: if ``displacement`` is not a tensor, or ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(elastic)

    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    # The scripting check must stay first in this chain: while scripting, only this branch is taken.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
    elif isinstance(inpt, PIL.Image.Image):
        return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1755
1756
1757
1758
1759


# Alias: ``elastic_transform`` is the same function object as the ``elastic`` dispatcher.
elastic_transform = elastic


1760
1761
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    """Normalize a center-crop size to a ``[height, width]`` list.

    The input may be:
      * a single number -> truncated with ``int`` and used for both sides;
      * a one-element tuple/list -> its element used for both sides;
      * anything else -> converted with ``list``.
    """
    if isinstance(output_size, numbers.Number):
        side = int(output_size)
        return [side, side]
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        side = output_size[0]
        return [side, side]
    else:
        return list(output_size)
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786


def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
    """Return ``[left, top, right, bottom]`` padding that lets a too-large crop fit the image.

    Along each axis where the crop exceeds the image, the surplus is split in half. For an odd
    surplus, the extra pixel goes to the right/bottom side. Axes where the crop already fits get
    zero padding.
    """
    surplus_x = crop_width - image_width if crop_width > image_width else 0
    surplus_y = crop_height - image_height if crop_height > image_height else 0
    return [surplus_x // 2, surplus_y // 2, (surplus_x + 1) // 2, (surplus_y + 1) // 2]


def _center_crop_compute_crop_anchor(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
    """Return the ``(top, left)`` offset that centers a crop inside an image.

    Each offset is half the size difference, rounded with Python's ``round`` (round-half-to-even).
    Offsets are negative when the crop is larger than the image.
    """
    half_dy = (image_height - crop_height) / 2.0
    half_dx = (image_width - crop_width) / 2.0
    return int(round(half_dy)), int(round(half_dx))


1787
def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Crop the centered ``output_size`` region out of the last two dims of ``image``.

    Steps:
      * ``output_size`` is normalized with ``_center_crop_parse_output_size``.
      * If the crop is larger than the image along either axis, the image is first zero-padded
        using ``_center_crop_compute_padding``.
      * If the padded image already has exactly the crop size, it is returned as is.

    An empty input is only reshaped to the crop size.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)

    input_shape = image.shape
    if image.numel() == 0:
        return image.reshape(input_shape[:-2] + (crop_height, crop_width))

    height, width = input_shape[-2:]
    if crop_height > height or crop_width > width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, height, width)
        image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)
        height, width = image.shape[-2:]
        if crop_height == height and crop_width == width:
            return image

    top, left = _center_crop_compute_crop_anchor(crop_height, crop_width, height, width)
    bottom = top + crop_height
    right = left + crop_width
    return image[..., top:bottom, left:right]
1804
1805


1806
@torch.jit.unused
def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
    """Crop the centered ``output_size`` region out of a PIL image.

    If the crop is larger than the image along either axis, the image is first padded with ``0``
    (``pad_image_pil``). If the padded image already has exactly the crop size, it is returned as is.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    height, width = get_spatial_size_image_pil(image)

    if crop_height > height or crop_width > width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, height, width)
        image = pad_image_pil(image, padding_ltrb, fill=0)
        height, width = get_spatial_size_image_pil(image)
        if crop_height == height and crop_width == width:
            return image

    top, left = _center_crop_compute_crop_anchor(crop_height, crop_width, height, width)
    return crop_image_pil(image, top, left, crop_height, crop_width)
1821
1822


1823
1824
def center_crop_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Shift bounding boxes as a center crop of ``output_size`` would shift the image.

    The crop offset comes from ``_center_crop_compute_crop_anchor`` with ``spatial_size`` as
    ``(height, width)``. The actual cropping is delegated to ``crop_bounding_box``.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    top, left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size)
    return crop_bounding_box(bounding_box, format, top=top, left=left, height=crop_height, width=crop_width)
1832
1833


1834
1835
1836
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Center-crop a mask tensor.

    Masks with fewer than three dimensions get a temporary leading dimension for
    ``center_crop_image_tensor``. That dimension is removed from the result.
    """
    squeeze_after = mask.ndim < 3
    if squeeze_after:
        mask = mask.unsqueeze(0)

    cropped = center_crop_image_tensor(image=mask, output_size=output_size)
    if squeeze_after:
        cropped = cropped.squeeze(0)
    return cropped
1847
1848


1849
1850
1851
1852
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Center-crop a video tensor; delegates to ``center_crop_image_tensor``."""
    return center_crop_image_tensor(image=video, output_size=output_size)


def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
    """Center-crop a plain tensor, a TorchVision datapoint, or a PIL image to ``output_size``.

    Dispatch:
      * plain tensors (and every input while scripting) -> ``center_crop_image_tensor``
      * datapoints -> their own ``.center_crop`` method
      * PIL images -> ``center_crop_image_pil``

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(center_crop)

    # The scripting check must stay first in this chain: while scripting, only this branch is taken.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return center_crop_image_tensor(inpt, output_size)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.center_crop(output_size)
    elif isinstance(inpt, PIL.Image.Image):
        return center_crop_image_pil(inpt, output_size)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1868
1869


1870
def resized_crop_image_tensor(
    image: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Crop the ``(top, left, height, width)`` window from ``image`` and resize it to ``size``.

    ``interpolation`` and ``antialias`` are forwarded unchanged to ``resize_image_tensor``.
    """
    window = crop_image_tensor(image, top, left, height, width)
    resized = resize_image_tensor(window, size, interpolation=interpolation, antialias=antialias)
    return resized
1882
1883


1884
@torch.jit.unused
def resized_crop_image_pil(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
    """Crop the ``(top, left, height, width)`` window from a PIL image and resize it to ``size``."""
    window = crop_image_pil(image, top, left, height, width)
    resized = resize_image_pil(window, size, interpolation=interpolation)
    return resized
1896
1897


1898
1899
def resized_crop_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Crop bounding boxes to a window, then rescale them from the window size to ``size``.

    The spatial size reported by ``crop_bounding_box`` is discarded. The resize step uses the
    window's ``(height, width)`` as the source spatial size, and its result is returned directly.
    """
    cropped_boxes, _ = crop_bounding_box(bounding_box, format, top, left, height, width)
    return resize_bounding_box(cropped_boxes, spatial_size=(height, width), size=size)
1909
1910


1911
def resized_crop_mask(
    mask: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> torch.Tensor:
    """Crop the ``(top, left, height, width)`` window from a mask and resize it with ``resize_mask``."""
    return resize_mask(crop_mask(mask, top, left, height, width), size)
1921
1922


1923
1924
1925
1926
1927
1928
1929
def resized_crop_video(
    video: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Crop and resize a video tensor; delegates to ``resized_crop_image_tensor``."""
    return resized_crop_image_tensor(
        video,
        top,
        left,
        height,
        width,
        size,
        interpolation=interpolation,
        antialias=antialias,
    )


1938
def resized_crop(
    inpt: datapoints._InputTypeJIT,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
    """Crop a window and resize it, for a plain tensor, a TorchVision datapoint, or a PIL image.

    Dispatch:
      * plain tensors (and every input while scripting) -> ``resized_crop_image_tensor``
      * datapoints -> their own ``.resized_crop`` method
      * PIL images -> ``resized_crop_image_pil`` (``antialias`` is not passed on this path)

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(resized_crop)

    # The scripting check must stay first in this chain: while scripting, only this branch is taken.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return resized_crop_image_tensor(
            inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
    elif isinstance(inpt, PIL.Image.Image):
        return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1964
1965


1966
1967
def _parse_five_crop_size(size: List[int]) -> List[int]:
    """Normalize a five/ten-crop size to two elements ``(h, w)``.

    The input may be:
      * a single number -> truncated with ``int`` and used for both sides;
      * a one-element tuple/list -> its element used for both sides;
      * any other sequence -> returned as is (same object) if it has exactly two elements.

    Raises:
        ValueError: if the normalized size does not have exactly two elements.
    """
    if isinstance(size, numbers.Number):
        side = int(size)
        size = [side, side]
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = [size[0], size[0]]

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")
    return size


def five_crop_image_tensor(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the four corner crops and the center crop of ``image``.

    Crops are taken over the last two dims, in the order top-left, top-right, bottom-left,
    bottom-right, center.

    Raises:
        ValueError: if the crop is larger than the image along either axis.
    """
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = image.shape[-2:]

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    # Offsets of the crops anchored at the bottom / right edges.
    last_top = image_height - crop_height
    last_left = image_width - crop_width

    top_left = crop_image_tensor(image, 0, 0, crop_height, crop_width)
    top_right = crop_image_tensor(image, 0, last_left, crop_height, crop_width)
    bottom_left = crop_image_tensor(image, last_top, 0, crop_height, crop_width)
    bottom_right = crop_image_tensor(image, last_top, last_left, crop_height, crop_width)
    center = center_crop_image_tensor(image, [crop_height, crop_width])
    return top_left, top_right, bottom_left, bottom_right, center


1998
@torch.jit.unused
def five_crop_image_pil(
    image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    """Return the four corner crops and the center crop of a PIL image.

    Order: top-left, top-right, bottom-left, bottom-right, center.

    Raises:
        ValueError: if the crop is larger than the image along either axis.
    """
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = get_spatial_size_image_pil(image)

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    # Offsets of the crops anchored at the bottom / right edges.
    last_top = image_height - crop_height
    last_left = image_width - crop_width

    top_left = crop_image_pil(image, 0, 0, crop_height, crop_width)
    top_right = crop_image_pil(image, 0, last_left, crop_height, crop_width)
    bottom_left = crop_image_pil(image, last_top, 0, crop_height, crop_width)
    bottom_right = crop_image_pil(image, last_top, last_left, crop_height, crop_width)
    center = center_crop_image_pil(image, [crop_height, crop_width])
    return top_left, top_right, bottom_left, bottom_right, center


2017
2018
2019
2020
2021
2022
def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Five-crop a video tensor; delegates to ``five_crop_image_tensor`` (crops the last two dims)."""
    crops = five_crop_image_tensor(video, size)
    return crops


# Input/output type alias used in the signatures of ``five_crop`` and ``ten_crop``: an image or a video.
ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]
2024
2025


2026
def five_crop(
    inpt: ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
    """Five-crop a plain tensor, an ``Image``/``Video`` datapoint, or a PIL image.

    For datapoints, the plain tensor is cropped and each of the five crops is re-wrapped with
    ``wrap_like`` so it keeps the input's datapoint type.

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(five_crop)

    # The scripting check must stay first in this chain: while scripting, only this branch is taken.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return five_crop_image_tensor(inpt, size)
    elif isinstance(inpt, datapoints.Image):
        crops = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
        return tuple(datapoints.Image.wrap_like(inpt, item) for item in crops)  # type: ignore[return-value]
    elif isinstance(inpt, datapoints.Video):
        crops = five_crop_video(inpt.as_subclass(torch.Tensor), size)
        return tuple(datapoints.Video.wrap_like(inpt, item) for item in crops)  # type: ignore[return-value]
    elif isinstance(inpt, PIL.Image.Image):
        return five_crop_image_pil(inpt, size)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
2047
2048


def ten_crop_image_tensor(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Return the five crops of ``image`` followed by the five crops of a flipped copy.

    The copy is flipped vertically when ``vertical_flip`` is true, otherwise horizontally.
    """
    crops = five_crop_image_tensor(image, size)
    mirrored = vertical_flip_image_tensor(image) if vertical_flip else horizontal_flip_image_tensor(image)
    mirrored_crops = five_crop_image_tensor(mirrored, size)
    return crops + mirrored_crops
2073
2074


2075
@torch.jit.unused
def ten_crop_image_pil(
    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
]:
    """Return the five crops of a PIL image followed by the five crops of a flipped copy.

    The copy is flipped vertically when ``vertical_flip`` is true, otherwise horizontally.
    """
    crops = five_crop_image_pil(image, size)
    mirrored = vertical_flip_image_pil(image) if vertical_flip else horizontal_flip_image_pil(image)
    mirrored_crops = five_crop_image_pil(mirrored, size)
    return crops + mirrored_crops


def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Ten-crop a video tensor; delegates to ``ten_crop_image_tensor``."""
    crops = ten_crop_image_tensor(video, size, vertical_flip)
    return crops


def ten_crop(
    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False
) -> Tuple[
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
]:
    """Ten-crop a plain tensor, an ``Image``/``Video`` datapoint, or a PIL image.

    For datapoints, the plain tensor is cropped and each of the ten crops is re-wrapped with
    ``wrap_like`` so it keeps the input's datapoint type.

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(ten_crop)

    # The scripting check must stay first in this chain: while scripting, only this branch is taken.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
    elif isinstance(inpt, datapoints.Image):
        crops = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
        return tuple(datapoints.Image.wrap_like(inpt, item) for item in crops)  # type: ignore[return-value]
    elif isinstance(inpt, datapoints.Video):
        crops = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
        return tuple(datapoints.Video.wrap_like(inpt, item) for item in crops)  # type: ignore[return-value]
    elif isinstance(inpt, PIL.Image.Image):
        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )