_geometry.py 78 KB
Newer Older
1
import math
2
import numbers
3
import warnings
4
from typing import List, Optional, Sequence, Tuple, Union
5
6
7

import PIL.Image
import torch
8
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
9

10
from torchvision import datapoints
11
12
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric
13
from torchvision.transforms.functional import (
14
    _check_antialias,
15
    _compute_resized_output_size as __compute_resized_output_size,
16
    _get_perspective_coeffs,
17
    _interpolation_modes_from_int,
18
    InterpolationMode,
19
    pil_modes_mapping,
20
21
    pil_to_tensor,
    to_pil_image,
22
)
23

24
25
from torchvision.utils import _log_api_usage_once

26
from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil
27

28
29
from ._utils import is_simple_tensor

30

31
32
33
34
35
36
37
38
39
40
41
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    """Normalize ``interpolation`` into an :class:`InterpolationMode`.

    Integer (Pillow-style) constants are translated with ``_interpolation_modes_from_int``;
    an :class:`InterpolationMode` is passed through unchanged.

    Raises:
        ValueError: if ``interpolation`` is neither an ``int`` nor an ``InterpolationMode``.
    """
    if isinstance(interpolation, int):
        return _interpolation_modes_from_int(interpolation)
    if isinstance(interpolation, InterpolationMode):
        return interpolation
    raise ValueError(
        f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
        f"but got {interpolation}."
    )


42
43
44
45
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Mirror ``image`` left-to-right by reversing its last dimension."""
    return torch.flip(image, dims=[-1])


46
47
def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """Mirror a PIL image left-to-right using the PIL functional backend."""
    flipped = _FP.hflip(image)
    return flipped
48
49


50
51
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    """Mirror a mask left-to-right; masks reuse the plain-tensor image kernel."""
    flipped_mask = horizontal_flip_image_tensor(mask)
    return flipped_mask
52
53


54
def horizontal_flip_bounding_box(
    bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
    """Mirror bounding boxes left-to-right inside an image of size ``spatial_size``.

    Args:
        bounding_box: boxes with 4 coordinates each (the tensor is viewed as ``(-1, 4)``).
        format: coordinate layout of ``bounding_box`` (XYXY, XYWH or CXCYWH).
        spatial_size: ``(height, width)`` of the image; only the width is used here.

    Returns:
        A new tensor with the same shape as ``bounding_box``; the input is left untouched.
    """
    shape = bounding_box.shape
    # clone() first so the in-place arithmetic below never writes into the caller's tensor.
    bounding_box = bounding_box.clone().reshape(-1, 4)
    width = spatial_size[1]

    if format == datapoints.BoundingBoxFormat.XYXY:
        # x1' = W - x2 and x2' = W - x1: the two x coordinates swap roles.
        bounding_box[:, [2, 0]] = bounding_box[:, [0, 2]].sub_(width).neg_()
    elif format == datapoints.BoundingBoxFormat.XYWH:
        # x' = W - (x + w); the box width itself is unchanged.
        bounding_box[:, 0].add_(bounding_box[:, 2]).sub_(width).neg_()
    else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
        # cx' = W - cx
        bounding_box[:, 0].sub_(width).neg_()

    return bounding_box.reshape(shape)
69
70


71
72
73
74
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    """Mirror every frame of ``video`` left-to-right via the plain-tensor image kernel."""
    return horizontal_flip_image_tensor(image=video)


Philip Meier's avatar
Philip Meier committed
75
def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    """Flip ``inpt`` left-to-right, dispatching on its type.

    Plain tensors (and every input while TorchScript-scripting) use the tensor kernel,
    TorchVision datapoints use their own ``horizontal_flip`` method, and PIL images use the PIL kernel.

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(horizontal_flip)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return horizontal_flip_image_tensor(inpt)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.horizontal_flip()
    elif isinstance(inpt, PIL.Image.Image):
        return horizontal_flip_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
90
91


92
93
94
95
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Mirror ``image`` top-to-bottom by reversing its second-to-last dimension."""
    return torch.flip(image, dims=[-2])


96
97
98
def vertical_flip_image_pil(img: PIL.Image.Image) -> PIL.Image.Image:
    """Mirror a PIL image top-to-bottom using the PIL functional backend's ``vflip``."""
    return _FP.vflip(img)


99
100
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    """Mirror a mask top-to-bottom; masks reuse the plain-tensor image kernel."""
    flipped_mask = vertical_flip_image_tensor(mask)
    return flipped_mask
101
102
103


def vertical_flip_bounding_box(
    bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
    """Mirror bounding boxes top-to-bottom inside an image of size ``spatial_size``.

    Args:
        bounding_box: boxes with 4 coordinates each (the tensor is viewed as ``(-1, 4)``).
        format: coordinate layout of ``bounding_box`` (XYXY, XYWH or CXCYWH).
        spatial_size: ``(height, width)`` of the image; only the height is used here.

    Returns:
        A new tensor with the same shape as ``bounding_box``; the input is left untouched.
    """
    shape = bounding_box.shape
    # clone() first so the in-place arithmetic below never writes into the caller's tensor.
    bounding_box = bounding_box.clone().reshape(-1, 4)
    height = spatial_size[0]

    if format == datapoints.BoundingBoxFormat.XYXY:
        # y1' = H - y2 and y2' = H - y1: the two y coordinates swap roles.
        bounding_box[:, [1, 3]] = bounding_box[:, [3, 1]].sub_(height).neg_()
    elif format == datapoints.BoundingBoxFormat.XYWH:
        # y' = H - (y + h); the box height itself is unchanged.
        bounding_box[:, 1].add_(bounding_box[:, 3]).sub_(height).neg_()
    else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
        # cy' = H - cy
        bounding_box[:, 1].sub_(height).neg_()

    return bounding_box.reshape(shape)
118
119


120
121
122
123
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
    """Mirror every frame of ``video`` top-to-bottom via the plain-tensor image kernel."""
    return vertical_flip_image_tensor(image=video)


Philip Meier's avatar
Philip Meier committed
124
def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    """Flip ``inpt`` top-to-bottom, dispatching on its type.

    Plain tensors (and every input while TorchScript-scripting) use the tensor kernel,
    TorchVision datapoints use their own ``vertical_flip`` method, and PIL images use the PIL kernel.

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(vertical_flip)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return vertical_flip_image_tensor(inpt)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.vertical_flip()
    elif isinstance(inpt, PIL.Image.Image):
        return vertical_flip_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
139
140


141
142
143
144
145
146
# We changed the names to align them with the transforms, e.g. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
# These are plain aliases: they are the very same function objects, not wrappers.
hflip = horizontal_flip
vflip = vertical_flip


147
def _compute_resized_output_size(
    spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
    """Compute the ``[height, width]`` that an input of ``spatial_size`` should be resized to.

    A bare ``int`` ``size`` is wrapped into a one-element list before delegating to the
    ``_compute_resized_output_size`` helper of ``torchvision.transforms.functional``.

    Raises:
        ValueError: if ``max_size`` is given while ``size`` is a sequence whose length is not 1.
    """
    if isinstance(size, int):
        size = [size]
    elif max_size is not None and len(size) != 1:
        raise ValueError(
            "max_size should only be passed if size specifies the length of the smaller edge, "
            "i.e. size should be an int or a sequence of length 1 in torchscript mode."
        )
    return __compute_resized_output_size(spatial_size, size=size, max_size=max_size)
158
159


160
161
162
def resize_image_tensor(
    image: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Resize the last two dimensions of ``image`` with ``torch.nn.functional.interpolate``.

    Args:
        image: tensor of shape ``(..., C, H, W)``; leading dimensions are flattened into a batch for
            ``interpolate`` and restored afterwards.
        size: target size, interpreted by ``_compute_resized_output_size``.
        interpolation: an ``InterpolationMode`` or the matching Pillow integer constant.
        max_size: optional bound forwarded to ``_compute_resized_output_size``.
        antialias: validated by ``_check_antialias``; forced to ``False`` for modes other than
            bilinear and bicubic.

    Returns:
        ``image`` itself if the output size equals the input size, otherwise a resized tensor with the
        input's dtype and shape ``(..., C, new_H, new_W)``.
    """
    interpolation = _check_interpolation(interpolation)
    antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation)
    # `_check_antialias` is expected to have resolved the "warn" string; the assert also narrows the
    # type for TorchScript. A `None` result is treated as "no antialiasing".
    assert not isinstance(antialias, str)
    antialias = False if antialias is None else antialias
    align_corners: Optional[bool] = None
    if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
        align_corners = False
    else:
        # The default of antialias should be True from 0.17, so we don't warn or
        # error if other interpolation modes are used. This is documented.
        antialias = False

    shape = image.shape
    numel = image.numel()
    num_channels, old_height, old_width = shape[-3:]
    new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return image
    elif numel > 0:
        image = image.reshape(-1, num_channels, old_height, old_width)

        dtype = image.dtype
        acceptable_dtypes = [torch.float32, torch.float64]
        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
            # uint8 dtype can be included for cpu and cuda input if nearest mode
            acceptable_dtypes.append(torch.uint8)
        elif image.device.type == "cpu":
            # uint8 dtype support for bilinear and bicubic is limited to cpu and
            # according to our benchmarks, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
            if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (
                interpolation == InterpolationMode.BICUBIC
            ):
                acceptable_dtypes.append(torch.uint8)

        strides = image.stride()
        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
            # contiguous even though the input is un-ambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
            # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
            # we should be able to remove this hack.
            new_strides = list(strides)
            new_strides[0] = numel
            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)

        # Dtypes that `interpolate` cannot (or should not) handle directly take a float32 round trip.
        need_cast = dtype not in acceptable_dtypes
        if need_cast:
            image = image.to(dtype=torch.float32)

        image = interpolate(
            image,
            size=[new_height, new_width],
            mode=interpolation.value,
            align_corners=align_corners,
            antialias=antialias,
        )

        if need_cast:
            if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
                # Bicubic can overshoot the input range, so clamp before casting back. This branch is only
                # reached off the CPU (e.g. on GPU): on CPU, bicubic accepts uint8 directly regardless of AVX2.
                image = image.clamp_(min=0, max=255)
            if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
                # Round instead of truncating when casting back to an integer dtype.
                image = image.round_()
            image = image.to(dtype=dtype)

    # For empty inputs (numel == 0) no interpolation happens; only the shape is updated.
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
237
238


239
@torch.jit.unused
def resize_image_pil(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
) -> PIL.Image.Image:
    """Resize a PIL image with ``PIL.Image.Image.resize``.

    Args:
        image: the PIL image to resize.
        size: target size, interpreted by ``_compute_resized_output_size``.
        interpolation: an ``InterpolationMode`` or the matching Pillow integer constant; mapped to a PIL
            resampling filter through ``pil_modes_mapping``.
        max_size: optional bound forwarded to ``_compute_resized_output_size``.

    Returns:
        ``image`` itself (not a copy) when the size is unchanged, otherwise a new resized image.
    """
    old_height, old_width = image.height, image.width
    new_height, new_width = _compute_resized_output_size(
        (old_height, old_width),
        size=size,  # type: ignore[arg-type]
        max_size=max_size,
    )

    interpolation = _check_interpolation(interpolation)

    if (new_height, new_width) == (old_height, old_width):
        return image

    # PIL expects (width, height).
    return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation])
259
260


261
262
263
def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
    """Resize a mask with nearest-neighbour interpolation (no value blending).

    Masks with fewer than 3 dimensions get a temporary leading dimension, because
    ``resize_image_tensor`` reads ``(C, H, W)`` from the last three dimensions; it is removed again
    before returning.
    """
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = resize_image_tensor(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)

    if needs_squeeze:
        output = output.squeeze(0)

    return output
274
275


276
def resize_bounding_box(
    bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Scale bounding boxes to match an image resized from ``spatial_size``.

    The same per-coordinate factors ``[w_ratio, h_ratio, w_ratio, h_ratio]`` are applied whatever the box
    format (this function takes no format argument). The result is cast back to the input dtype, so
    integer boxes are truncated.

    Args:
        bounding_box: boxes whose last dimension has 4 values.
        spatial_size: ``(height, width)`` of the image before resizing.
        size: target size, interpreted by ``_compute_resized_output_size``.
        max_size: optional bound forwarded to ``_compute_resized_output_size``.

    Returns:
        ``(boxes, (new_height, new_width))``. When the size does not change, the input tensor and
        ``spatial_size`` are returned unchanged.
    """
    old_height, old_width = spatial_size
    new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return bounding_box, spatial_size

    w_ratio = new_width / old_width
    h_ratio = new_height / old_height
    ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_box.device)
    return (
        bounding_box.mul(ratios).to(bounding_box.dtype),
        (new_height, new_width),
    )
292
293


294
295
296
def resize_video(
    video: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Resize every frame of ``video``.

    This delegates to ``resize_image_tensor``, which treats all dimensions before ``(C, H, W)`` as batch.
    """
    return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


304
def resize(
    inpt: datapoints._InputTypeJIT,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
    """Resize ``inpt``, dispatching on its type.

    Plain tensors (and every input while TorchScript-scripting) use ``resize_image_tensor``, datapoints
    use their own ``resize`` method, and PIL images use ``resize_image_pil``. For PIL input,
    ``antialias=False`` cannot be honoured; a warning is emitted and the argument is ignored.

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(resize)
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
    elif isinstance(inpt, PIL.Image.Image):
        if antialias is False:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
326
327


328
def _affine_parse_args(
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
    """Validate and normalize the arguments shared by the affine kernels.

    Normalizations performed:
        * an ``int`` ``angle`` becomes a ``float``;
        * a tuple ``translate`` becomes a list;
        * a scalar ``shear`` ``s`` becomes ``[s, 0.0]``, a one-element sequence ``[s]`` becomes ``[s, s]``;
        * ``center`` (if given) becomes a list of floats.

    Returns:
        ``(angle, translate, shear, center)`` after normalization.

    Raises:
        TypeError: for wrongly typed ``angle``, ``translate``, ``shear``, ``interpolation`` or ``center``.
        ValueError: if ``translate`` does not have 2 values, ``scale`` is not positive, or ``shear`` does
            not end up with exactly 2 values.
    """
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        # A single value shears along x only.
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None:
        if not isinstance(center, (list, tuple)):
            raise TypeError("Argument center should be a sequence")
        else:
            center = [float(c) for c in center]

    return angle, translate, shear, center


381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
def _get_inverse_affine_matrix(
    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
    # Helper method to compute inverse matrix for affine transformation

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                      , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # Cached results
    cos_sy = math.cos(sy)
    tan_sx = math.tan(sx)
    rot_minus_sy = rot - sy
    cx_plus_tx = cx + tx
    cy_plus_ty = cy + ty

    # Rotate Scale Shear (RSS) without scaling
    a = math.cos(rot_minus_sy) / cos_sy
    b = -(a * tan_sx + math.sin(rot))
    c = math.sin(rot_minus_sy) / cos_sy
    d = math.cos(rot) - c * tan_sx

    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
    else:
        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
        # Apply inverse of center translation: RSS * C^-1
        # and then apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy

    return matrix


def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired of PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
    # Points are shifted due to affine matrix torch convention about
    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
    half_w = 0.5 * w
    half_h = 0.5 * h
    pts = torch.tensor(
        [
            [-half_w, -half_h, 1.0],
            [-half_w, half_h, 1.0],
            [half_w, half_h, 1.0],
            [half_w, -half_h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, max_vals = new_pts.aminmax(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    halfs = torch.tensor((half_w, half_h))
    min_vals.add_(halfs)
    max_vals.add_(halfs)

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    inv_tol = 1.0 / tol
    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
    size = cmax.sub_(cmin)
    return int(size[0]), int(size[1])  # w, h


def _apply_grid_transform(
    img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints._FillTypeJIT
) -> torch.Tensor:
    """Resample a 4D batch ``img`` with ``grid_sample`` and paint out-of-bounds pixels with ``fill``.

    Args:
        img: tensor indexed as ``(N, C, H, W)``; non-``grid.dtype`` inputs are converted to ``grid.dtype``
            for sampling and rounded/cast back afterwards.
        grid: sampling grid with a floating dtype; it is expanded to ``N`` when ``N > 1``.
        mode: ``grid_sample`` mode; ``"nearest"`` uses a hard fill mask, anything else blends with the mask.
        fill: ``None`` (zeros padding), a number, or one value per channel.

    Returns:
        The resampled tensor with the dtype of ``img``.
    """
    # We are using context knowledge that grid should have float dtype
    fp = img.dtype == grid.dtype
    float_img = img if fp else img.to(grid.dtype)

    shape = float_img.shape
    if shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(shape[0], -1, -1, -1)

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
        float_img = torch.cat((float_img, mask), dim=1)

    float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        # After sampling, the extra channel is ~1 where the source was in bounds and ~0 where it was padded.
        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
        mask = mask.expand_as(float_img)
        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
        fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
        if mode == "nearest":
            bool_mask = mask < 0.5
            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
        else:  # 'bilinear'
            # The following is mathematically equivalent to:
            # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)

    img = float_img.round_().to(img.dtype) if not fp else float_img

    return img
514
515
516
517
518
519


def _assert_grid_transform_inputs(
    image: torch.Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: datapoints._FillTypeJIT,
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    """Validate the arguments of a ``grid_sample``-based geometric kernel.

    Args:
        image: input tensor; ``image.shape[-3]`` is taken as the channel count when checking ``fill``.
        matrix: optional flattened affine matrix; must be a ``list`` of 6 values.
        interpolation: interpolation mode name, checked against ``supported_interpolation_modes``.
        fill: ``None``, an ``int``/``float``, or a tuple/list holding either one value or one per channel.
        supported_interpolation_modes: mode names the calling kernel accepts.
        coeffs: optional coefficient list; must have 8 values.

    Raises:
        TypeError: if ``matrix`` is not a list.
        ValueError: for any other invalid argument (including a ``fill`` of unsupported type).
    """
    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        elif len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None:
        if isinstance(fill, (tuple, list)):
            length = len(fill)
            num_channels = image.shape[-3]
            if length > 1 and length != num_channels:
                raise ValueError(
                    "The number of elements in 'fill' cannot broadcast to match the number of "
                    f"channels of the image ({length} != {num_channels})"
                )
        elif not isinstance(fill, (int, float)):
            raise ValueError("Argument fill should be either int, float, tuple or list")

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _affine_grid(
    theta: torch.Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> torch.Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
    dtype = theta.dtype
    device = theta.device

    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


576
def affine_image_tensor(
    image: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to a tensor image, keeping its spatial size.

    Args:
        image: tensor whose last three dimensions are ``(C, H, W)``. 3D inputs get a temporary batch
            dimension and inputs with more than 4 dimensions are flattened to 4D, then restored.
        angle, translate, scale, shear: affine parameters, validated by ``_affine_parse_args``.
        interpolation: ``NEAREST`` or ``BILINEAR`` (or their Pillow integer constants).
        fill: value(s) for pixels sampled from outside the input; ``None`` means zeros.
        center: rotation/shear pivot in pixel coordinates; defaults to the image center.

    Returns:
        The transformed tensor; empty inputs are returned unchanged.
    """
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim

    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
        needs_unsquash = True
    elif ndim == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True
    else:
        needs_unsquash = False

    height, width = shape[-2:]

    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [float(t) for t in translate]
    # grid_sample pulls pixels from the source, hence the inverse matrix.
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    return output
625
626


627
@torch.jit.unused
def affine_image_pil(
    image: PIL.Image.Image,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    """Apply an affine transformation to a PIL image via the PIL functional backend.

    ``center`` is in pixel coordinates and defaults to ``(width * 0.5, height * 0.5)``. Pillow expects
    the inverse matrix, which is what ``_get_inverse_affine_matrix`` returns by default.
    """
    interpolation = _check_interpolation(interpolation)
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
    # it is visually better to estimate the center without 0.5 offset
    # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
    if center is None:
        height, width = get_spatial_size_image_pil(image)
        center = [width * 0.5, height * 0.5]
    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

    return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)
650
651


652
def _affine_bounding_box_with_expand(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    expand: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Apply an affine transformation to bounding boxes.

    Each box's 4 corners are transformed with the forward affine matrix and the axis-aligned box
    enclosing them is taken, then clamped to the (possibly new) image size.

    Args:
        bounding_box: boxes in ``format``; any shape whose elements regroup into rows of 4.
        format: box coordinate layout; the output uses the same layout.
        spatial_size: ``(height, width)`` of the image.
        angle, translate, scale, shear, center: affine parameters (``center`` in pixel coordinates,
            defaulting to the image center).
        expand: if ``True``, shift the boxes into an enlarged canvas that contains the whole transformed
            image and return that canvas size.

    Returns:
        ``(boxes, spatial_size)`` with boxes in the input's shape and dtype. Empty inputs are returned as is.
    """
    if bounding_box.numel() == 0:
        return bounding_box, spatial_size

    original_shape = bounding_box.shape
    original_dtype = bounding_box.dtype
    # Work on a float copy so the in-place format conversion never touches the caller's tensor.
    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
    dtype = bounding_box.dtype
    device = bounding_box.device
    bounding_box = (
        convert_format_bounding_box(
            bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
        )
    ).reshape(-1, 4)

    angle, translate, shear, center = _affine_parse_args(
        angle, translate, scale, shear, InterpolationMode.NEAREST, center
    )

    if center is None:
        height, width = spatial_size
        center = [width * 0.5, height * 0.5]

    # Boxes move with the image, so use the forward (non-inverted) matrix here.
    affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
    transposed_affine_matrix = (
        torch.tensor(
            affine_vector,
            dtype=dtype,
            device=device,
        )
        .reshape(2, 3)
        .T
    )
    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-right, bottom-left corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
    # 2) Now let's transform the points using affine matrix
    transformed_points = torch.matmul(points, transposed_affine_matrix)
    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

    if expand:
        # Compute minimum point for transformed image frame:
        # Points are Top-Left, Bottom-Left, Bottom-Right, Top-Right points.
        height, width = spatial_size
        points = torch.tensor(
            [
                [0.0, 0.0, 1.0],
                [0.0, float(height), 1.0],
                [float(width), float(height), 1.0],
                [float(width), 0.0, 1.0],
            ],
            dtype=dtype,
            device=device,
        )
        new_points = torch.matmul(points, transposed_affine_matrix)
        tr = torch.amin(new_points, dim=0, keepdim=True)
        # Translate bounding boxes so the expanded canvas starts at (0, 0)
        out_bboxes.sub_(tr.repeat((1, 2)))
        # Estimate meta-data for image with inverted=True and with center=[0,0]
        affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
        new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
        spatial_size = (new_height, new_width)

    out_bboxes = clamp_bounding_box(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
    out_bboxes = convert_format_bounding_box(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)

    out_bboxes = out_bboxes.to(original_dtype)
    return out_bboxes, spatial_size
739
740
741
742


def affine_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to bounding boxes without enlarging the canvas.

    Thin wrapper around ``_affine_bounding_box_with_expand`` called with ``expand=False``. The spatial size
    that helper also returns is discarded.

    Args:
        bounding_box: Boxes expressed in ``format``.
        format: Box format of the input; the returned boxes use the same format.
        spatial_size: ``(height, width)`` of the canvas the boxes live on.
        angle: Rotation angle, forwarded unchanged.
        translate: Translation, forwarded unchanged.
        scale: Scale factor, forwarded unchanged.
        shear: Shear values, forwarded unchanged.
        center: Optional center of the transformation, forwarded unchanged.

    Returns:
        The transformed boxes.
    """
    out_box, _ = _affine_bounding_box_with_expand(
        bounding_box,
        format=format,
        spatial_size=spatial_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
        expand=False,
    )
    return out_box


def affine_mask(
    mask: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to a segmentation mask.

    The mask goes through the image kernel with nearest-neighbour interpolation, so neighbouring mask values
    are never mixed. A mask with fewer than three dimensions temporarily gets a leading dimension, which is
    removed again from the result.

    Args:
        mask: Mask tensor; 2-D masks are accepted.
        angle: Rotation angle, forwarded to ``affine_image_tensor``.
        translate: Translation, forwarded unchanged.
        scale: Scale factor, forwarded unchanged.
        shear: Shear values, forwarded unchanged.
        fill: Fill value for areas outside the source, forwarded unchanged.
        center: Optional center of the transformation, forwarded unchanged.

    Returns:
        The transformed mask with the same number of dimensions as ``mask``.
    """
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = affine_image_tensor(
        mask,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output

796

def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to a video tensor.

    Videos are handled by the image tensor kernel; all arguments are forwarded unchanged to
    ``affine_image_tensor``.
    """
    return affine_image_tensor(
        video,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )


def affine(
    inpt: datapoints._InputTypeJIT,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
    """Apply an affine transformation to any supported input type.

    Dispatch:
        * while scripting, or when ``is_simple_tensor(inpt)`` is true -> ``affine_image_tensor``;
        * TorchVision datapoints -> their own ``affine`` method;
        * PIL images -> ``affine_image_pil``.

    API usage is logged once when not scripting.

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(affine)

    # TODO: consider deprecating integers from angle and shear in the future
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return affine_image_tensor(
            inpt,
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.affine(
            angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
        )
    elif isinstance(inpt, PIL.Image.Image):
        return affine_image_pil(
            inpt,
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


def rotate_image_tensor(
    image: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate an image tensor (or a batch of them) by ``angle``.

    Args:
        image: Tensor of shape ``(..., C, H, W)``. Leading dimensions are flattened into a single batch
            dimension for the grid transform and restored afterwards.
        angle: Rotation angle. It is negated before building the matrix because ``rotate`` and ``affine`` use
            opposite rotation directions.
        interpolation: ``InterpolationMode`` or corresponding Pillow integer constant. Only ``"nearest"`` and
            ``"bilinear"`` are passed as supported to ``_assert_grid_transform_inputs``.
        expand: If True, the output size is computed by ``_compute_affine_output_size`` instead of keeping
            ``(H, W)``.
        center: Optional rotation center in pixel coordinates; defaults to the image center. If given together
            with ``expand=True`` a warning is emitted (the value is still used to build the matrix).
        fill: Fill value for areas outside the source, forwarded to ``_apply_grid_transform``.

    Returns:
        Tensor of shape ``(..., C, new_H, new_W)``. For empty inputs no resampling happens; only the output shape
        is computed.
    """
    interpolation = _check_interpolation(interpolation)

    shape = image.shape
    num_channels, height, width = shape[-3:]

    center_f = [0.0, 0.0]
    if center is not None:
        if expand:
            # TODO: Do we actually want to warn, or just document this?
            warnings.warn("The provided center argument has no effect on the result if expand is True")
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    # due to current incoherence of rotation angle direction between affine and rotate implementations
    # we need to set -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

    if image.numel() > 0:
        image = image.reshape(-1, num_channels, height, width)

        _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

        ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
        # Integer images are sampled on a float32 grid.
        dtype = image.dtype if torch.is_floating_point(image) else torch.float32
        theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
        grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
        output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

        new_height, new_width = output.shape[-2:]
    else:
        output = image
        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)

    return output.reshape(shape[:-3] + (num_channels, new_height, new_width))


@torch.jit.unused
def rotate_image_pil(
    image: PIL.Image.Image,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> PIL.Image.Image:
    """Rotate a PIL image by delegating to ``_FP.rotate``.

    The interpolation mode is validated and mapped to the corresponding Pillow constant. If both ``center`` and
    ``expand`` are given, a warning is emitted and ``center`` is dropped before delegating.
    """
    interpolation = _check_interpolation(interpolation)

    if center is not None and expand:
        warnings.warn("The provided center argument has no effect on the result if expand is True")
        center = None

    return _FP.rotate(
        image, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
    )


def rotate_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Rotate bounding boxes, optionally enlarging the canvas.

    Implemented as a pure rotation through ``_affine_bounding_box_with_expand``: no translation, unit scale and
    no shear. The angle is negated because ``rotate`` and ``affine`` use opposite rotation directions. If both
    ``center`` and ``expand`` are given, a warning is emitted and ``center`` is dropped.

    Returns:
        ``(boxes, spatial_size)``: the rotated boxes in ``format`` and the ``(height, width)`` of the resulting
        canvas.
    """
    if center is not None and expand:
        warnings.warn("The provided center argument has no effect on the result if expand is True")
        center = None

    return _affine_bounding_box_with_expand(
        bounding_box,
        format=format,
        spatial_size=spatial_size,
        angle=-angle,
        translate=[0.0, 0.0],
        scale=1.0,
        shear=[0.0, 0.0],
        center=center,
        expand=expand,
    )


def rotate_mask(
    mask: torch.Tensor,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate a segmentation mask with nearest-neighbour interpolation.

    A mask with fewer than three dimensions temporarily gets a leading dimension, so it can go through
    ``rotate_image_tensor``; that dimension is removed again from the result.
    """
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = rotate_image_tensor(
        mask,
        angle=angle,
        expand=expand,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output

982

def rotate_video(
    video: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate a video tensor; videos are handled by ``rotate_image_tensor`` with all arguments forwarded."""
    return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


def rotate(
    inpt: datapoints._InputTypeJIT,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
    """Rotate any supported input type.

    Dispatch:
        * while scripting, or when ``is_simple_tensor(inpt)`` is true -> ``rotate_image_tensor``;
        * TorchVision datapoints -> their own ``rotate`` method;
        * PIL images -> ``rotate_image_pil``.

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(rotate)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
    elif isinstance(inpt, PIL.Image.Image):
        return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, (tuple, list)):
        if len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        elif len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        else:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )
    else:
        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")

    return [pad_left, pad_right, pad_top, pad_bottom]
1040

1041

def pad_image_tensor(
    image: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad the spatial dimensions of an image tensor.

    Args:
        image: Tensor whose last three dimensions are ``(C, H, W)``; leading dimensions are preserved.
        padding: PIL-style padding (see ``_parse_pad_padding``).
        fill: ``None`` (treated as 0), a scalar, or a list of values. A scalar or a 1-element list uses the
            scalar path; a longer list is applied per channel and only works with ``"constant"`` mode.
        padding_mode: One of ``"constant"``, ``"edge"``, ``"reflect"`` or ``"symmetric"``.

    Returns:
        The padded tensor.

    Raises:
        ValueError: for an unknown ``padding_mode`` (or a non-constant mode with a vector ``fill``).
    """
    # Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses
    # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use
    # `torch_pad` internally.
    torch_padding = _parse_pad_padding(padding)

    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
        raise ValueError(
            f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
            f"but got `'{padding_mode}'`."
        )

    if fill is None:
        fill = 0

    if isinstance(fill, (int, float)):
        return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
    elif len(fill) == 1:
        return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
    else:
        return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)


def _pad_with_scalar_fill(
1071
    image: torch.Tensor,
1072
1073
1074
    torch_padding: List[int],
    fill: Union[int, float],
    padding_mode: str,
1075
) -> torch.Tensor:
1076
1077
    shape = image.shape
    num_channels, height, width = shape[-3:]
1078

1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
    batch_size = 1
    for s in shape[:-3]:
        batch_size *= s

    image = image.reshape(batch_size, num_channels, height, width)

    if padding_mode == "edge":
        # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
        # name.
        padding_mode = "replicate"

    if padding_mode == "constant":
        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
    elif padding_mode in ("reflect", "replicate"):
        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
        # TODO: See https://github.com/pytorch/pytorch/issues/40763
        dtype = image.dtype
        if not image.is_floating_point():
            needs_cast = True
            image = image.to(torch.float32)
        else:
            needs_cast = False
1102

1103
1104
1105
1106
1107
        image = torch_pad(image, torch_padding, mode=padding_mode)

        if needs_cast:
            image = image.to(dtype)
    else:  # padding_mode == "symmetric"
1108
        image = _pad_symmetric(image, torch_padding)
1109
1110

    new_height, new_width = image.shape[-2:]
1111

1112
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
1113
1114


1115
# TODO: This should be removed once torch_pad supports non-scalar padding values
1116
def _pad_with_vector_fill(
1117
    image: torch.Tensor,
1118
    torch_padding: List[int],
1119
    fill: List[float],
1120
    padding_mode: str,
1121
1122
1123
1124
) -> torch.Tensor:
    if padding_mode != "constant":
        raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

1125
1126
    output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    left, right, top, bottom = torch_padding
1127
    fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139

    if top > 0:
        output[..., :top, :] = fill
    if left > 0:
        output[..., :, :left] = fill
    if bottom > 0:
        output[..., -bottom:, :] = fill
    if right > 0:
        output[..., :, -right:] = fill
    return output


1140
1141
1142
# Alias: padding of PIL images is delegated unchanged to the PIL helper `_FP.pad`.
pad_image_pil = _FP.pad


def pad_mask(
    mask: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad a segmentation mask with a scalar fill value.

    ``fill`` defaults to 0. A mask with fewer than three dimensions temporarily gets a leading dimension, so it
    can go through ``pad_image_tensor``; that dimension is removed again from the result.

    Raises:
        ValueError: if ``fill`` is a tuple or list (non-scalar fills are not supported for masks).
    """
    if fill is None:
        fill = 0

    if isinstance(fill, (tuple, list)):
        raise ValueError("Non-scalar fill value is not supported")

    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = pad_image_tensor(mask, padding=padding, fill=fill, padding_mode=padding_mode)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


def pad_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    padding: List[int],
    padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Shift bounding boxes to account for padding and enlarge the canvas accordingly.

    For XYXY boxes both corners are shifted by ``(left, top)``. For the other formats only the first two
    coordinates are shifted. The result is clamped to the padded canvas.

    Returns:
        ``(boxes, (height + top + bottom, width + left + right))``.

    Raises:
        ValueError: if ``padding_mode`` is anything but ``"constant"``.
    """
    if padding_mode not in ["constant"]:
        # TODO: add support of other padding modes
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

    left, right, top, bottom = _parse_pad_padding(padding)

    if format == datapoints.BoundingBoxFormat.XYXY:
        pad = [left, top, left, top]
    else:
        pad = [left, top, 0, 0]
    bounding_box = bounding_box + torch.tensor(pad, dtype=bounding_box.dtype, device=bounding_box.device)

    height, width = spatial_size
    height += top + bottom
    width += left + right
    spatial_size = (height, width)

    return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size


def pad_video(
    video: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad a video tensor; videos are handled by ``pad_image_tensor`` with all arguments forwarded."""
    return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)


def pad(
    inpt: datapoints._InputTypeJIT,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> datapoints._InputTypeJIT:
    """Pad any supported input type.

    Dispatch:
        * while scripting, or when ``is_simple_tensor(inpt)`` is true -> ``pad_image_tensor``;
        * TorchVision datapoints -> their own ``pad`` method;
        * PIL images -> ``pad_image_pil``.

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(pad)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)

    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
    elif isinstance(inpt, PIL.Image.Image):
        return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop the ``height x width`` window whose top-left corner is ``(top, left)``.

    The window may extend past the image on any side, or lie completely outside of it. The missing area is
    zero-padded via ``_pad_with_scalar_fill``, so for non-negative ``height``/``width`` the result always has
    spatial size ``(height, width)``.

    Args:
        image: Tensor of shape ``(..., H, W)``. When padding is needed, ``_pad_with_scalar_fill`` additionally
            requires at least ``(C, H, W)``.
        top: Row of the window's top edge; may be negative.
        left: Column of the window's left edge; may be negative.
        height: Window height.
        width: Window width.

    Returns:
        The cropped (and possibly zero-padded) tensor.
    """
    h, w = image.shape[-2:]

    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        # Clamp both slice bounds at 0. A negative end index (window completely left of / above the image) would
        # otherwise be interpreted by Python slicing as counted from the end, keeping the wrong pixels.
        image = image[..., max(top, 0) : max(bottom, 0), max(left, 0) : max(right, 0)]
        torch_padding = [
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    return image[..., top:bottom, left:right]


1246
1247
1248
# Alias: cropping of PIL images is delegated unchanged to the PIL helper `_FP.crop`.
crop_image_pil = _FP.crop


def crop_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Move bounding boxes into the coordinate frame of a crop window.

    For XYXY boxes both corners are shifted by ``(-left, -top)``. For the other formats only the first two
    coordinates are shifted. Negative ``left``/``top`` therefore act as an implicit pad. The result is clamped
    to the new canvas.

    Returns:
        ``(boxes, (height, width))``.
    """
    # Crop or implicit pad if left and/or top have negative values:
    if format == datapoints.BoundingBoxFormat.XYXY:
        sub = [left, top, left, top]
    else:
        sub = [left, top, 0, 0]

    bounding_box = bounding_box - torch.tensor(sub, dtype=bounding_box.dtype, device=bounding_box.device)
    spatial_size = (height, width)

    return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size


def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a segmentation mask.

    A mask with fewer than three dimensions temporarily gets a leading dimension, so it can go through
    ``crop_image_tensor``; that dimension is removed again from the result.
    """
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = crop_image_tensor(mask, top, left, height, width)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop every frame of a video; videos share the image tensor kernel."""
    return crop_image_tensor(video, top=top, left=left, height=height, width=width)


def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
    """Crop any supported input type.

    Dispatch:
        * while scripting, or when ``is_simple_tensor(inpt)`` is true -> ``crop_image_tensor``;
        * TorchVision datapoints -> their own ``crop`` method;
        * PIL images -> ``crop_image_pil``.

    Raises:
        TypeError: if ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return crop_image_tensor(inpt, top, left, height, width)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.crop(top, left, height, width)
    elif isinstance(inpt, PIL.Image.Image):
        return crop_image_pil(inpt, top, left, height, width)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394

    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
1321
    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
1322
    base_grid[..., 0].copy_(x_grid)
1323
    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
1324
1325
1326
1327
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
1328
1329
1330
    shape = (1, oh * ow, 3)
    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))
1331
1332
1333
1334
1335

    output_grid = output_grid1.div_(output_grid2).sub_(1.0)
    return output_grid.view(1, oh, ow, 2)


1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
def _perspective_coefficients(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]],
) -> List[float]:
    if coefficients is not None:
        if startpoints is not None and endpoints is not None:
            raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.")
        elif len(coefficients) != 8:
            raise ValueError("Argument coefficients should have 8 float values")
        return coefficients
    elif startpoints is not None and endpoints is not None:
        return _get_perspective_coeffs(startpoints, endpoints)
    else:
        raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")


def perspective_image_tensor(
    image: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transformation to an image tensor (or a batch of them).

    Args:
        image: Tensor of shape ``(C, H, W)``, ``(N, C, H, W)`` or with more leading dimensions. 3-D inputs get
            a temporary batch dimension; inputs with more than 4 dimensions are flattened to 4-D and restored.
        startpoints: Optional source points (see ``_perspective_coefficients``).
        endpoints: Optional destination points (see ``_perspective_coefficients``).
        interpolation: ``InterpolationMode`` or Pillow integer constant; only ``"nearest"`` and ``"bilinear"``
            are passed as supported to ``_assert_grid_transform_inputs``.
        fill: Fill value for areas outside the source, forwarded to ``_apply_grid_transform``.
        coefficients: Optional explicit 8 coefficients (mutually exclusive with the point lists).

    Returns:
        The transformed tensor with the same shape as ``image``. Empty tensors are returned unchanged.
    """
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)

    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim

    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
        needs_unsquash = True
    elif ndim == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True
    else:
        needs_unsquash = False

    _assert_grid_transform_inputs(
        image,
        matrix=None,
        interpolation=interpolation.value,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    # The output keeps the input's spatial size.
    oh, ow = shape[-2:]
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    return output


@torch.jit.unused
def perspective_image_pil(
    image: PIL.Image.Image,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
    """Apply a perspective transformation to a PIL image via ``_FP.perspective``.

    The coefficients are resolved by ``_perspective_coefficients``. The interpolation mode is validated and
    mapped to the Pillow constant. Note the default interpolation is ``BICUBIC``, unlike the tensor kernel's
    ``BILINEAR``.
    """
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)
    return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)


def perspective_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transformation to bounding boxes.

    The perspective coefficients map end points to start points, so they are analytically inverted here. Each
    box's four corners are pushed through the inverted mapping. The axis-aligned box enclosing the transformed
    corners is then clamped to ``spatial_size`` and converted back to ``format``.

    Args:
        bounding_box: Boxes in ``format``; any leading shape is allowed (internally reshaped to ``(-1, 4)``).
        format: Box format of input and output.
        spatial_size: ``(height, width)`` used for clamping.
        startpoints: Optional source points (see ``_perspective_coefficients``).
        endpoints: Optional destination points (see ``_perspective_coefficients``).
        coefficients: Optional explicit 8 coefficients.

    Returns:
        Boxes with the original shape of ``bounding_box``. They are cast back to the dtype of the
        format-converted input before clamping. Empty inputs are returned unchanged.

    Raises:
        RuntimeError: if the coefficients cannot be inverted (zero denominator).
    """
    if bounding_box.numel() == 0:
        return bounding_box

    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)

    original_shape = bounding_box.shape
    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
    bounding_box = (
        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
    device = bounding_box.device

    # perspective_coeffs are computed as endpoint -> start point
    # We have to invert perspective_coeffs for bboxes:
    # (x, y) - end point and (x_out, y_out) - start point
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # and we would like to get:
    # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # and compute inv_coeffs in terms of coeffs

    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
    if denom == 0:
        raise RuntimeError(
            f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
            f"Denominator is zero, denom={denom}"
        )

    inv_coeffs = [
        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
        (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
        (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
        (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
    ]

    theta1 = torch.tensor(
        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
        dtype=dtype,
        device=device,
    )

    theta2 = torch.tensor(
        [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
    )

    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-right, bottom-left corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    # The points and the homogeneous ones column are built in the compute dtype so that the matmuls below never
    # mix dtypes (e.g. float16 boxes against a float32 ones column).
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2).to(dtype)
    points = torch.cat([points, torch.ones(points.shape[0], 1, dtype=dtype, device=device)], dim=-1)
    # 2) Now let's transform the points using perspective matrices
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)

    numer_points = torch.matmul(points, theta1.T)
    denom_points = torch.matmul(points, theta2.T)
    transformed_points = numer_points.div_(denom_points)

    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)

    out_bboxes = clamp_bounding_box(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=spatial_size,
    )

    # out_bboxes should be of shape [N boxes, 4]

    return convert_format_bounding_box(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)


1507
1508
def perspective_mask(
    mask: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective warp to a mask tensor.

    Masks with fewer than 3 dimensions get a temporary leading dimension, which is removed again
    afterwards. Nearest-neighbour interpolation is always used, so the warp never blends
    neighbouring mask values into new ones.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    warped = perspective_image_tensor(
        mask,
        startpoints,
        endpoints,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        coefficients=coefficients,
    )

    return warped.squeeze(0) if needs_squeeze else warped

1529

1530
1531
def perspective_video(
    video: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective warp to a video tensor.

    Videos are handled exactly like image tensors: every argument is forwarded unchanged to
    ``perspective_image_tensor``.
    """
    return perspective_image_tensor(
        video,
        startpoints,
        endpoints,
        interpolation=interpolation,
        fill=fill,
        coefficients=coefficients,
    )
1541
1542


1543
def perspective(
    inpt: datapoints._InputTypeJIT,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
    """Apply a perspective warp to a plain tensor, a TorchVision datapoint, or a PIL image.

    Dispatch order:
      * while scripting, or for plain tensors -> ``perspective_image_tensor``
      * datapoints -> their own ``perspective`` method
      * PIL images -> ``perspective_image_pil``

    Raises:
        TypeError: if ``inpt`` is none of the supported input kinds.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(perspective)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return perspective_image_tensor(
            inpt,
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.perspective(
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
    elif isinstance(inpt, PIL.Image.Image):
        return perspective_image_pil(
            inpt,
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1570
1571


1572
def elastic_image_tensor(
    image: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Warp an image tensor by adding ``displacement`` to an identity sampling grid.

    Args:
        image: tensor whose last two dims are (height, width). 3D inputs get a temporary batch
            dim, and inputs with more than 4 dims are flattened to 4D and restored afterwards.
        displacement: must have shape ``(1, height, width, 2)``. It is moved to the image's device
            and cast to the compute dtype when needed.
        interpolation: interpolation mode, or the matching Pillow integer constant.
        fill: forwarded to ``_apply_grid_transform``.

    Raises:
        ValueError: if ``displacement`` does not have the expected shape.
    """
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    expected_shape = (1,) + shape[-2:] + (2,)
    if expected_shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    device = image.device
    # Non-floating images are sampled with a float32 grid. This also means a uint8 image paired with
    # a float64 displacement is computed in float32; this is a known limitation.
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32

    # Patch: elastic transform should support (cpu,f16) input. Compute in float32 and cast back at the end.
    is_cpu_half = device.type == "cpu" and dtype == torch.float16
    if is_cpu_half:
        image = image.to(torch.float32)
        dtype = torch.float32

    # Reduce the input to a 4D batch for the grid transform; the original shape is restored below.
    ndim = image.ndim
    needs_unsquash = ndim == 3 or ndim > 4
    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
    elif ndim == 3:
        image = image.unsqueeze(0)

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    height, width = shape[-2:]
    grid = _create_identity_grid((height, width), device=device, dtype=dtype)
    grid = grid.add_(displacement)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    return output.to(torch.float16) if is_cpu_half else output
1626
1627


1628
@torch.jit.unused
def elastic_image_pil(
    image: PIL.Image.Image,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> PIL.Image.Image:
    """Elastic-warp a PIL image by round-tripping it through ``elastic_image_tensor``.

    The result is converted back to a PIL image with the same mode as the input.
    """
    warped = elastic_image_tensor(pil_to_tensor(image), displacement, interpolation=interpolation, fill=fill)
    return to_pil_image(warped, mode=image.mode)
1638
1639


1640
def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
1641
    sy, sx = size
1642
1643
    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
1644
1645
    base_grid[..., 0].copy_(x_grid)

1646
    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
1647
1648
1649
1650
1651
    base_grid[..., 1].copy_(y_grid)

    return base_grid


1652
1653
def elastic_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    displacement: torch.Tensor,
) -> torch.Tensor:
    """Approximately apply an elastic displacement field to bounding boxes.

    Steps:
      1. Convert each box to XYXY.
      2. Look up its four corners in an approximate inverse of the warped sampling grid.
      3. Enclose the moved corners in an axis-aligned box.
      4. Clamp that box to the image and convert it back to ``format``.

    Args:
        bounding_box: boxes in ``format``, with any leading shape and the 4 coordinates in the
            last dim.
        format: coordinate format of the input and of the returned boxes.
        spatial_size: ``(height, width)`` of the image the boxes belong to.
        displacement: displacement field added to the identity grid when warping the image. It is
            cast to the compute dtype and device of the boxes when needed.

    Returns:
        Transformed boxes with the same shape and format as ``bounding_box``.
    """
    if bounding_box.numel() == 0:
        return bounding_box

    device = bounding_box.device
    # Non-floating boxes are transformed in float32 and cast back to their dtype at the end.
    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    original_shape = bounding_box.shape
    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
    bounding_box = convert_format_bounding_box(
        bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY
    ).reshape(-1, 4)

    # We construct an approximation of the inverse grid as inv_grid = id_grid - displacement.
    # This is not an exact inverse of the grid.
    id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
    inv_grid = id_grid.sub_(displacement)

    # Corners of every XYXY box as (x, y) points: (x1, y1), (x2, y1), (x2, y2), (x1, y2).
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    if points.is_floating_point():
        points = points.ceil_()
    index_xy = points.to(dtype=torch.long)

    # A corner may lie on the far edge of the image (e.g. x2 == width for a box covering the
    # whole image) or outside it. Clamp the grid indices to the valid range. Otherwise such
    # corners raise an IndexError, or silently wrap to the opposite edge through negative indexing.
    image_height, image_width = spatial_size
    index_x = index_xy[:, 0].clamp(0, image_width - 1)
    index_y = index_xy[:, 1].clamp(0, image_height - 1)

    # Map the sampled normalized grid coordinates back to pixel coordinates (t_size is (width, height)).
    t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

    # Enclose the 4 moved corners of each box in an axis-aligned XYXY box.
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = clamp_bounding_box(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=spatial_size,
    )

    return convert_format_bounding_box(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
1701
1702


1703
1704
1705
def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Elastic-warp a mask tensor with nearest-neighbour sampling.

    Masks with fewer than 3 dimensions get a temporary leading dimension, which is removed again
    afterwards. Nearest-neighbour interpolation keeps the warp from blending mask values into new ones.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    warped = elastic_image_tensor(
        mask,
        displacement=displacement,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
    )

    return warped.squeeze(0) if needs_squeeze else warped
1720
1721


1722
1723
1724
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Elastic-warp a video tensor.

    Videos are treated exactly like image tensors, so all arguments go straight to
    ``elastic_image_tensor``.
    """
    return elastic_image_tensor(
        video,
        displacement,
        interpolation=interpolation,
        fill=fill,
    )
1729
1730


1731
def elastic(
    inpt: datapoints._InputTypeJIT,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
    """Elastic-warp a plain tensor, a TorchVision datapoint, or a PIL image.

    Dispatch order:
      * while scripting, or for plain tensors -> ``elastic_image_tensor``
      * datapoints -> their own ``elastic`` method
      * PIL images -> ``elastic_image_pil``

    Raises:
        TypeError: if ``displacement`` is not a tensor, or ``inpt`` is not a supported input kind.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(elastic)

    # Validate the displacement before dispatching on the input type.
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
    elif isinstance(inpt, PIL.Image.Image):
        return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1754
1755
1756
1757
1758


# Alias of ``elastic``; presumably kept for naming parity with the v1 ``elastic_transform`` API — confirm.
elastic_transform = elastic


1759
1760
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    if isinstance(output_size, numbers.Number):
1761
1762
        s = int(output_size)
        return [s, s]
1763
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
1764
        return [output_size[0], output_size[0]]
1765
1766
    else:
        return list(output_size)
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785


def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
    return [
        (crop_width - image_width) // 2 if crop_width > image_width else 0,
        (crop_height - image_height) // 2 if crop_height > image_height else 0,
        (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
        (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
    ]


def _center_crop_compute_crop_anchor(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop_top, crop_left


1786
def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Center-crop an image tensor to ``output_size``.

    Behaviour:
      * Empty tensors are only reshaped to the requested spatial size.
      * When the crop is larger than the image along either axis, the image is first zero-padded.
        If the padded image already matches the crop size, it is returned directly.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)

    shape = image.shape
    if image.numel() == 0:
        return image.reshape(shape[:-2] + (crop_height, crop_width))

    image_height, image_width = shape[-2:]
    needs_padding = crop_height > image_height or crop_width > image_width
    if needs_padding:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)

        image_height, image_width = image.shape[-2:]
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    crop_bottom = crop_top + crop_height
    crop_right = crop_left + crop_width
    return image[..., crop_top:crop_bottom, crop_left:crop_right]
1803
1804


1805
@torch.jit.unused
def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
    """Center-crop a PIL image to ``output_size``.

    When the crop is larger than the image along either axis, the image is first padded with 0.
    If the padded image already matches the crop size, it is returned directly.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    image_height, image_width = get_spatial_size_image_pil(image)

    needs_padding = crop_height > image_height or crop_width > image_width
    if needs_padding:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = pad_image_pil(image, padding_ltrb, fill=0)

        image_height, image_width = get_spatial_size_image_pil(image)
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)
1820
1821


1822
1823
def center_crop_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Center-crop bounding boxes.

    The crop window is computed from ``spatial_size`` (height, width) and ``output_size``, and the
    work is delegated to ``crop_bounding_box``, whose result is returned as-is.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    image_height, image_width = spatial_size
    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return crop_bounding_box(
        bounding_box,
        format,
        top=crop_top,
        left=crop_left,
        height=crop_height,
        width=crop_width,
    )
1831
1832


1833
1834
1835
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Center-crop a mask tensor.

    Masks with fewer than 3 dimensions get a temporary leading dimension, which is removed again
    afterwards.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    cropped = center_crop_image_tensor(image=mask, output_size=output_size)

    return cropped.squeeze(0) if needs_squeeze else cropped
1846
1847


1848
1849
1850
1851
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Center-crop a video tensor; videos are handled exactly like image tensors."""
    cropped = center_crop_image_tensor(video, output_size)
    return cropped


Philip Meier's avatar
Philip Meier committed
1852
def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
    """Center-crop a plain tensor, a TorchVision datapoint, or a PIL image.

    Dispatch order:
      * while scripting, or for plain tensors -> ``center_crop_image_tensor``
      * datapoints -> their own ``center_crop`` method
      * PIL images -> ``center_crop_image_pil``

    Raises:
        TypeError: if ``inpt`` is none of the supported input kinds.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(center_crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return center_crop_image_tensor(inpt, output_size)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.center_crop(output_size)
    elif isinstance(inpt, PIL.Image.Image):
        return center_crop_image_pil(inpt, output_size)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1867
1868


1869
def resized_crop_image_tensor(
    image: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Crop the ``(top, left, height, width)`` window from an image tensor, then resize it to ``size``."""
    return resize_image_tensor(
        crop_image_tensor(image, top, left, height, width),
        size,
        interpolation=interpolation,
        antialias=antialias,
    )
1881
1882


1883
@torch.jit.unused
def resized_crop_image_pil(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
    """Crop the ``(top, left, height, width)`` window from a PIL image, then resize it to ``size``."""
    return resize_image_pil(
        crop_image_pil(image, top, left, height, width),
        size,
        interpolation=interpolation,
    )
1895
1896


1897
1898
def resized_crop_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Crop bounding boxes to a window, then resize them from the window size to ``size``.

    The spatial size reported by ``crop_bounding_box`` is ignored. The resize always starts from
    the ``(height, width)`` of the crop window.
    """
    cropped_boxes, _ = crop_bounding_box(bounding_box, format, top, left, height, width)
    window_size = (height, width)
    return resize_bounding_box(cropped_boxes, spatial_size=window_size, size=size)
1908
1909


1910
def resized_crop_mask(
    mask: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> torch.Tensor:
    """Crop the ``(top, left, height, width)`` window from a mask, then resize it to ``size``."""
    return resize_mask(crop_mask(mask, top, left, height, width), size)
1920
1921


1922
1923
1924
1925
1926
1927
1928
def resized_crop_video(
    video: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Crop-and-resize a video tensor; videos are handled exactly like image tensors."""
    return resized_crop_image_tensor(
        video,
        top,
        left,
        height,
        width,
        size=size,
        interpolation=interpolation,
        antialias=antialias,
    )


1937
def resized_crop(
    inpt: datapoints._InputTypeJIT,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
    """Crop a window and resize it, for a plain tensor, a TorchVision datapoint, or a PIL image.

    Dispatch order:
      * while scripting, or for plain tensors -> ``resized_crop_image_tensor``
      * datapoints -> their own ``resized_crop`` method
      * PIL images -> ``resized_crop_image_pil`` (which takes no ``antialias`` argument)

    Raises:
        TypeError: if ``inpt`` is none of the supported input kinds.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(resized_crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return resized_crop_image_tensor(
            inpt,
            top,
            left,
            height,
            width,
            size=size,
            interpolation=interpolation,
            antialias=antialias,
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.resized_crop(
            top,
            left,
            height,
            width,
            size=size,
            interpolation=interpolation,
            antialias=antialias,
        )
    elif isinstance(inpt, PIL.Image.Image):
        return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1963
1964


1965
1966
def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
1967
1968
        s = int(size)
        size = [s, s]
1969
    elif isinstance(size, (tuple, list)) and len(size) == 1:
1970
1971
        s = size[0]
        size = [s, s]
1972
1973
1974
1975
1976
1977
1978
1979

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size


def five_crop_image_tensor(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the four corner crops and the center crop of an image tensor.

    The crops are returned in the order (top-left, top-right, bottom-left, bottom-right, center).

    Raises:
        ValueError: if ``size`` is malformed or larger than the image along either axis.
    """
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = image.shape[-2:]

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    # Offsets of the crops anchored to the bottom and right edges.
    bottom = image_height - crop_height
    right = image_width - crop_width

    tl = crop_image_tensor(image, 0, 0, crop_height, crop_width)
    tr = crop_image_tensor(image, 0, right, crop_height, crop_width)
    bl = crop_image_tensor(image, bottom, 0, crop_height, crop_width)
    br = crop_image_tensor(image, bottom, right, crop_height, crop_width)
    center = center_crop_image_tensor(image, [crop_height, crop_width])

    return tl, tr, bl, br, center


1997
@torch.jit.unused
def five_crop_image_pil(
    image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    """Return the four corner crops and the center crop of a PIL image.

    The crops are returned in the order (top-left, top-right, bottom-left, bottom-right, center).

    Raises:
        ValueError: if ``size`` is malformed or larger than the image along either axis.
    """
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = get_spatial_size_image_pil(image)

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    # Offsets of the crops anchored to the bottom and right edges.
    bottom = image_height - crop_height
    right = image_width - crop_width

    tl = crop_image_pil(image, 0, 0, crop_height, crop_width)
    tr = crop_image_pil(image, 0, right, crop_height, crop_width)
    bl = crop_image_pil(image, bottom, 0, crop_height, crop_width)
    br = crop_image_pil(image, bottom, right, crop_height, crop_width)
    center = center_crop_image_pil(image, [crop_height, crop_width])

    return tl, tr, bl, br, center


2016
2017
2018
2019
2020
2021
def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Five-crop a video tensor; videos are handled exactly like image tensors."""
    crops = five_crop_image_tensor(video, size)
    return crops


Philip Meier's avatar
Philip Meier committed
2022
# Input/output annotation for the five_crop/ten_crop dispatchers, which only accept image or video inputs.
ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]
2023
2024


2025
def five_crop(
    inpt: ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
    """Five-crop a plain tensor, an ``Image``/``Video`` datapoint, or a PIL image.

    Datapoint inputs are unwrapped to plain tensors, cropped, and each crop is re-wrapped with the
    input's datapoint type.

    Raises:
        TypeError: if ``inpt`` is none of the supported input kinds.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(five_crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return five_crop_image_tensor(inpt, size)
    elif isinstance(inpt, datapoints.Image):
        crops = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
        return tuple(datapoints.Image.wrap_like(inpt, crop) for crop in crops)  # type: ignore[return-value]
    elif isinstance(inpt, datapoints.Video):
        crops = five_crop_video(inpt.as_subclass(torch.Tensor), size)
        return tuple(datapoints.Video.wrap_like(inpt, crop) for crop in crops)  # type: ignore[return-value]
    elif isinstance(inpt, PIL.Image.Image):
        return five_crop_image_pil(inpt, size)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
2046
2047


Philip Meier's avatar
Philip Meier committed
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
def ten_crop_image_tensor(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Return the five crops of an image tensor followed by the five crops of its flipped copy.

    The copy is flipped horizontally by default, or vertically when ``vertical_flip`` is True.
    """
    original_crops = five_crop_image_tensor(image, size)

    mirrored = vertical_flip_image_tensor(image) if vertical_flip else horizontal_flip_image_tensor(image)
    mirrored_crops = five_crop_image_tensor(mirrored, size)

    return original_crops + mirrored_crops
2072
2073


2074
@torch.jit.unused
def ten_crop_image_pil(
    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
]:
    """Return the five crops of a PIL image followed by the five crops of its flipped copy.

    The copy is flipped horizontally by default, or vertically when ``vertical_flip`` is True.
    """
    flip = vertical_flip_image_pil if vertical_flip else horizontal_flip_image_pil
    # Left operand is evaluated first, so the unflipped crops are computed before flipping.
    return five_crop_image_pil(image, size) + five_crop_image_pil(flip(image), size)


def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Ten-crop a video tensor; videos are handled exactly like image tensors."""
    crops = ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
    return crops


def ten_crop(
    inpt: ImageOrVideoTypeJIT, size: List[int], vertical_flip: bool = False
) -> Tuple[
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
]:
    """Ten-crop a plain tensor, an ``Image``/``Video`` datapoint, or a PIL image.

    Produces the five crops of the input followed by the five crops of its flipped copy
    (horizontal flip by default, vertical when ``vertical_flip`` is True). Datapoint inputs are
    unwrapped to plain tensors, cropped, and each crop is re-wrapped with the input's datapoint type.

    Raises:
        TypeError: if ``inpt`` is none of the supported input kinds.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(ten_crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
    elif isinstance(inpt, datapoints.Image):
        output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
        return tuple(datapoints.Image.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
    elif isinstance(inpt, datapoints.Video):
        output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
        return tuple(datapoints.Video.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
    elif isinstance(inpt, PIL.Image.Image):
        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )