_geometry.py 78 KB
Newer Older
1
import math
2
import numbers
3
import warnings
4
from typing import List, Optional, Sequence, Tuple, Union
5
6
7

import PIL.Image
import torch
8
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
9

10
from torchvision import datapoints
11
12
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric
13
from torchvision.transforms.functional import (
14
    _check_antialias,
15
    _compute_resized_output_size as __compute_resized_output_size,
16
    _get_perspective_coeffs,
17
    _interpolation_modes_from_int,
18
    InterpolationMode,
19
    pil_modes_mapping,
20
21
    pil_to_tensor,
    to_pil_image,
22
)
23

24
25
from torchvision.utils import _log_api_usage_once

26
from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil
27

28
29
from ._utils import is_simple_tensor

30

31
32
33
34
35
36
37
38
39
40
41
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    """Normalize ``interpolation`` to an ``InterpolationMode``.

    Integers are translated through ``_interpolation_modes_from_int``;
    ``InterpolationMode`` members are returned untouched.

    Raises:
        ValueError: If ``interpolation`` is neither an int nor an ``InterpolationMode``.
    """
    if isinstance(interpolation, InterpolationMode):
        return interpolation
    if isinstance(interpolation, int):
        return _interpolation_modes_from_int(interpolation)
    raise ValueError(
        f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
        f"but got {interpolation}."
    )


42
43
44
45
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Return ``image`` with its last dimension reversed."""
    return torch.flip(image, dims=[-1])


46
47
def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """Horizontally flip a PIL image by delegating to ``_FP.hflip``."""
    flipped = _FP.hflip(image)
    return flipped
48
49


50
51
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    """Horizontally flip a mask; masks reuse the image-tensor kernel."""
    flipped = horizontal_flip_image_tensor(mask)
    return flipped
52
53


54
def horizontal_flip_bounding_box(
    bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
    """Mirror bounding boxes left-to-right within an image of the given size.

    Args:
        bounding_box: Tensor whose elements can be viewed as rows of 4 box coordinates.
        format: Coordinate layout of ``bounding_box`` (XYXY, XYWH, otherwise treated as CXCYWH).
        spatial_size: ``(height, width)`` of the image; only the width (index 1) is used.

    Returns:
        A new tensor with the same shape as the input; the input is not modified.
    """
    shape = bounding_box.shape

    # Clone first so the in-place arithmetic below never touches the caller's tensor.
    bounding_box = bounding_box.clone().reshape(-1, 4)

    if format == datapoints.BoundingBoxFormat.XYXY:
        # x -> width - x, and the two x coordinates trade places so that x1 <= x2 still holds.
        bounding_box[:, [2, 0]] = bounding_box[:, [0, 2]].sub_(spatial_size[1]).neg_()
    elif format == datapoints.BoundingBoxFormat.XYWH:
        # New left edge is width - (x + w).
        bounding_box[:, 0].add_(bounding_box[:, 2]).sub_(spatial_size[1]).neg_()
    else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
        # Only the center x needs mirroring.
        bounding_box[:, 0].sub_(spatial_size[1]).neg_()

    return bounding_box.reshape(shape)
69
70


71
72
73
74
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    """Horizontally flip a video; videos reuse the image-tensor kernel."""
    flipped = horizontal_flip_image_tensor(video)
    return flipped


Philip Meier's avatar
Philip Meier committed
75
def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    """Horizontally flip ``inpt``, dispatching on its type.

    Plain tensors (and every input while scripting) go to the image-tensor kernel,
    datapoints dispatch to their own ``horizontal_flip`` method, and PIL images go
    to the PIL kernel.

    Raises:
        TypeError: If ``inpt`` is none of the supported types.
    """
    # API usage logging is skipped while scripting.
    if not torch.jit.is_scripting():
        _log_api_usage_once(horizontal_flip)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return horizontal_flip_image_tensor(inpt)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.horizontal_flip()
    elif isinstance(inpt, PIL.Image.Image):
        return horizontal_flip_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
90
91


92
93
94
95
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Return ``image`` with its second-to-last dimension reversed."""
    return torch.flip(image, dims=[-2])


Philip Meier's avatar
Philip Meier committed
96
97
def vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """Vertically flip a PIL image by delegating to ``_FP.vflip``."""
    # Annotated with the image class ``PIL.Image.Image`` (``PIL.Image`` is the module),
    # matching ``horizontal_flip_image_pil``.
    return _FP.vflip(image)
98
99


100
101
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    """Vertically flip a mask; masks reuse the image-tensor kernel."""
    flipped = vertical_flip_image_tensor(mask)
    return flipped
102
103
104


def vertical_flip_bounding_box(
    bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
    """Mirror bounding boxes top-to-bottom within an image of the given size.

    Args:
        bounding_box: Tensor whose elements can be viewed as rows of 4 box coordinates.
        format: Coordinate layout of ``bounding_box`` (XYXY, XYWH, otherwise treated as CXCYWH).
        spatial_size: ``(height, width)`` of the image; only the height (index 0) is used.

    Returns:
        A new tensor with the same shape as the input; the input is not modified.
    """
    shape = bounding_box.shape

    # Clone first so the in-place arithmetic below never touches the caller's tensor.
    bounding_box = bounding_box.clone().reshape(-1, 4)

    if format == datapoints.BoundingBoxFormat.XYXY:
        # y -> height - y, and the two y coordinates trade places so that y1 <= y2 still holds.
        bounding_box[:, [1, 3]] = bounding_box[:, [3, 1]].sub_(spatial_size[0]).neg_()
    elif format == datapoints.BoundingBoxFormat.XYWH:
        # New top edge is height - (y + h).
        bounding_box[:, 1].add_(bounding_box[:, 3]).sub_(spatial_size[0]).neg_()
    else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
        # Only the center y needs mirroring.
        bounding_box[:, 1].sub_(spatial_size[0]).neg_()

    return bounding_box.reshape(shape)
119
120


121
122
123
124
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
    """Vertically flip a video; videos reuse the image-tensor kernel."""
    flipped = vertical_flip_image_tensor(video)
    return flipped


Philip Meier's avatar
Philip Meier committed
125
def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    """Vertically flip ``inpt``, dispatching on its type.

    Plain tensors (and every input while scripting) go to the image-tensor kernel,
    datapoints dispatch to their own ``vertical_flip`` method, and PIL images go
    to the PIL kernel.

    Raises:
        TypeError: If ``inpt`` is none of the supported types.
    """
    # API usage logging is skipped while scripting.
    if not torch.jit.is_scripting():
        _log_api_usage_once(vertical_flip)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return vertical_flip_image_tensor(inpt)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.vertical_flip()
    elif isinstance(inpt, PIL.Image.Image):
        return vertical_flip_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
140
141


142
143
144
145
146
147
# The dispatchers were renamed to match the transform classes (e.g. `RandomHorizontalFlip`). Because `hflip` and
# `vflip` are prevalent and well understood, they stay available as plain aliases and are not deprecated.
hflip = horizontal_flip
vflip = vertical_flip


148
def _compute_resized_output_size(
    spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
    """Validate ``size``/``max_size`` and delegate to the shared output-size helper.

    Args:
        spatial_size: Current ``(height, width)``.
        size: Requested size; a bare int is accepted and wrapped into a one-element list.
        max_size: Optional bound; only allowed when ``size`` has exactly one element.

    Raises:
        ValueError: If ``max_size`` is given together with a ``size`` sequence whose length is not 1.
    """
    if isinstance(size, int):
        size = [size]
    elif max_size is not None and len(size) != 1:
        raise ValueError(
            "max_size should only be passed if size specifies the length of the smaller edge, "
            "i.e. size should be an int or a sequence of length 1 in torchscript mode."
        )
    return __compute_resized_output_size(spatial_size, size=size, max_size=max_size)
159
160


161
162
163
def resize_image_tensor(
    image: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Resize an image tensor whose last three dimensions are (channels, height, width).

    Args:
        image: Input tensor; any leading dimensions are folded into a batch dimension for interpolation.
        size: Requested output size, resolved through ``_compute_resized_output_size``.
        interpolation: ``InterpolationMode`` or the corresponding Pillow integer constant.
        max_size: Optional bound forwarded to ``_compute_resized_output_size``.
        antialias: Anti-aliasing flag, normalized by ``_check_antialias``; forced to ``False`` for
            modes other than bilinear and bicubic.

    Returns:
        The resized tensor with the original leading dimensions. The input itself is returned
        when the computed size equals the current size.
    """
    interpolation = _check_interpolation(interpolation)
    antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation)
    # Narrows the type of `antialias` to Optional[bool] for the code below.
    assert not isinstance(antialias, str)
    antialias = False if antialias is None else antialias
    align_corners: Optional[bool] = None
    if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
        align_corners = False
    else:
        # The default of antialias should be True from 0.17, so we don't warn or
        # error if other interpolation modes are used. This is documented.
        antialias = False

    shape = image.shape
    numel = image.numel()
    num_channels, old_height, old_width = shape[-3:]
    new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return image
    elif numel > 0:
        # Collapse all leading dimensions into one batch dimension.
        image = image.reshape(-1, num_channels, old_height, old_width)

        dtype = image.dtype
        acceptable_dtypes = [torch.float32, torch.float64]
        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
            # uint8 dtype can be included for cpu and cuda input if nearest mode
            acceptable_dtypes.append(torch.uint8)
        elif image.device.type == "cpu":
            # uint8 dtype support for bilinear and bicubic is limited to cpu and
            # according to our benchmarks, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
            if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (
                interpolation == InterpolationMode.BICUBIC
            ):
                acceptable_dtypes.append(torch.uint8)

        strides = image.stride()
        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
            # contiguous even though the input is un-ambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
            # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
            # we should be able to remove this hack.
            new_strides = list(strides)
            new_strides[0] = numel
            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)

        # Dtypes outside `acceptable_dtypes` are interpolated in float32 and cast back afterwards.
        need_cast = dtype not in acceptable_dtypes
        if need_cast:
            image = image.to(dtype=torch.float32)

        image = interpolate(
            image,
            size=[new_height, new_width],
            mode=interpolation.value,
            align_corners=align_corners,
            antialias=antialias,
        )

        if need_cast:
            if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
                # This path is hit on non-AVX archs, or on GPU.
                image = image.clamp_(min=0, max=255)
            # Round before casting back to an integer dtype.
            if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
                image = image.round_()
            image = image.to(dtype=dtype)

    # Restore the original leading dimensions around the new spatial size.
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
238
239


240
@torch.jit.unused
def resize_image_pil(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
) -> PIL.Image.Image:
    """Resize a PIL image.

    Args:
        image: Image to resize.
        size: Requested output size, resolved through ``_compute_resized_output_size``.
        interpolation: ``InterpolationMode`` or the corresponding Pillow integer constant.
        max_size: Optional bound forwarded to ``_compute_resized_output_size``.

    Returns:
        The resized image, or ``image`` itself when the computed size equals the current size.
    """
    old_height, old_width = image.height, image.width
    new_height, new_width = _compute_resized_output_size(
        (old_height, old_width),
        size=size,  # type: ignore[arg-type]
        max_size=max_size,
    )

    interpolation = _check_interpolation(interpolation)

    if (new_height, new_width) == (old_height, old_width):
        return image

    # PIL takes the target size as (width, height).
    return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation])
260
261


262
263
264
def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
    """Resize a mask with nearest-neighbour interpolation.

    Masks with fewer than 3 dimensions get a temporary leading dimension so that
    they can go through ``resize_image_tensor``; it is removed again before returning.
    """
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = resize_image_tensor(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)

    if needs_squeeze:
        output = output.squeeze(0)

    return output
275
276


277
def resize_bounding_box(
    bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Scale bounding boxes to follow a resize of their image.

    Args:
        bounding_box: Box tensor whose last dimension holds 4 coordinates ordered x, y, x-like, y-like.
        spatial_size: Current ``(height, width)`` of the image.
        size: Requested output size, resolved through ``_compute_resized_output_size``.
        max_size: Optional bound forwarded to ``_compute_resized_output_size``.

    Returns:
        ``(boxes, (new_height, new_width))``. The inputs are returned unchanged when the
        computed size equals ``spatial_size``.
    """
    old_height, old_width = spatial_size
    new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return bounding_box, spatial_size

    w_ratio = new_width / old_width
    h_ratio = new_height / old_height
    # Horizontal entries scale with the width ratio, vertical ones with the height ratio.
    ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_box.device)
    return (
        # Cast back to the input dtype after the multiplication.
        bounding_box.mul(ratios).to(bounding_box.dtype),
        (new_height, new_width),
    )
293
294


295
296
297
def resize_video(
    video: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Resize a video; videos reuse ``resize_image_tensor`` with the same arguments."""
    return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


305
def resize(
    inpt: datapoints._InputTypeJIT,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
    """Resize ``inpt``, dispatching on its type.

    Plain tensors (and every input while scripting) go to ``resize_image_tensor``,
    datapoints dispatch to their own ``resize`` method, and PIL images go to
    ``resize_image_pil``. For PIL input ``antialias`` is not forwarded; a warning is
    emitted when it is explicitly ``False``.

    Raises:
        TypeError: If ``inpt`` is none of the supported types.
    """
    # API usage logging is skipped while scripting.
    if not torch.jit.is_scripting():
        _log_api_usage_once(resize)
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
    elif isinstance(inpt, PIL.Image.Image):
        if antialias is False:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
327
328


329
def _affine_parse_args(
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
    """Validate and normalize the arguments shared by the affine kernels.

    Normalization performed:
        * an int ``angle`` becomes a float,
        * a tuple ``translate`` becomes a list,
        * a scalar ``shear`` becomes ``[shear, 0.0]``, a one-element sequence is duplicated,
          and a tuple becomes a list,
        * ``center`` entries are converted to float.

    Returns:
        ``(angle, translate, shear, center)`` after normalization.

    Raises:
        TypeError: If ``angle``, ``translate``, ``shear``, ``interpolation`` or ``center``
            has an unsupported type.
        ValueError: If ``translate`` does not have length 2, ``scale`` is not positive, or
            ``shear`` does not end up with two values.
    """
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    # A single number is taken as the first shear value; the second defaults to 0.
    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    # A one-element sequence applies the same value to both shear entries.
    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None:
        if not isinstance(center, (list, tuple)):
            raise TypeError("Argument center should be a sequence")
        else:
            center = [float(c) for c in center]

    return angle, translate, shear, center


382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
def _get_inverse_affine_matrix(
    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
    # Helper method to compute inverse matrix for affine transformation

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                      , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # Cached results
    cos_sy = math.cos(sy)
    tan_sx = math.tan(sx)
    rot_minus_sy = rot - sy
    cx_plus_tx = cx + tx
    cy_plus_ty = cy + ty

    # Rotate Scale Shear (RSS) without scaling
    a = math.cos(rot_minus_sy) / cos_sy
    b = -(a * tan_sx + math.sin(rot))
    c = math.sin(rot_minus_sy) / cos_sy
    d = math.cos(rot) - c * tan_sx

    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
    else:
        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
        # Apply inverse of center translation: RSS * C^-1
        # and then apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy

    return matrix


def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired of PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
    # Points are shifted due to affine matrix torch convention about
    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
    half_w = 0.5 * w
    half_h = 0.5 * h
    pts = torch.tensor(
        [
            [-half_w, -half_h, 1.0],
            [-half_w, half_h, 1.0],
            [half_w, half_h, 1.0],
            [half_w, -half_h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, max_vals = new_pts.aminmax(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    halfs = torch.tensor((half_w, half_h))
    min_vals.add_(halfs)
    max_vals.add_(halfs)

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    inv_tol = 1.0 / tol
    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
    size = cmax.sub_(cmin)
    return int(size[0]), int(size[1])  # w, h


def _apply_grid_transform(
    img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints._FillTypeJIT
) -> torch.Tensor:
    """Sample ``img`` at the locations in ``grid`` and fill out-of-image areas with ``fill``.

    Args:
        img: 4D image batch; dimension 0 is the batch and dimension 1 the channels.
        grid: Sampling grid passed to ``grid_sample``; expanded over the batch when it has more than one image.
        mode: Interpolation mode passed to ``grid_sample``.
        fill: ``None`` for no explicit fill, a scalar, or a tuple/list of per-channel values.

    Returns:
        The sampled image in the dtype of ``img``; results computed in another dtype are
        rounded before being cast back.
    """
    # We are using context knowledge that grid should have float dtype
    fp = img.dtype == grid.dtype
    float_img = img if fp else img.to(grid.dtype)

    shape = float_img.shape
    if shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(shape[0], -1, -1, -1)

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
        float_img = torch.cat((float_img, mask), dim=1)

    float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        # Split the sampled mask channel back off the image channels.
        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
        mask = mask.expand_as(float_img)
        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
        fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
        if mode == "nearest":
            # Pixels where the sampled mask is below 0.5 are replaced by the fill value.
            bool_mask = mask < 0.5
            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
        else:  # 'bilinear'
            # The following is mathematically equivalent to:
            # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)

    img = float_img.round_().to(img.dtype) if not fp else float_img

    return img
515
516
517
518
519
520


def _assert_grid_transform_inputs(
    image: torch.Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: datapoints._FillTypeJIT,
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    """Validate the arguments of a grid-based transform.

    Args:
        image: Image tensor; only its channel count (``shape[-3]``) is inspected.
        matrix: Optional affine matrix; must be a list of 6 values when given.
        interpolation: Interpolation mode name.
        fill: ``None``, an int/float, or a tuple/list of per-channel values.
        supported_interpolation_modes: Mode names accepted for ``interpolation``.
        coeffs: Optional coefficients; must have 8 values when given.

    Raises:
        TypeError: If ``matrix`` is given but is not a list.
        ValueError: If ``matrix``/``coeffs`` have the wrong length, ``fill`` has an unsupported
            type or a length that matches neither 1 nor the channel count, or ``interpolation``
            is not supported.
    """
    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        elif len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None:
        if isinstance(fill, (tuple, list)):
            length = len(fill)
            num_channels = image.shape[-3]
            # Sequences longer than one element must match the channel count exactly.
            if length > 1 and length != num_channels:
                raise ValueError(
                    "The number of elements in 'fill' cannot broadcast to match the number of "
                    f"channels of the image ({length} != {num_channels})"
                )
        elif not isinstance(fill, (int, float)):
            raise ValueError("Argument fill should be either int, float, tuple or list")

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _affine_grid(
    theta: torch.Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> torch.Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
    dtype = theta.dtype
    device = theta.device

    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


577
def affine_image_tensor(
    image: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transform to an image tensor, keeping its spatial size.

    Args:
        image: Tensor whose last two dimensions are (height, width). Inputs that are not 4D
            are reshaped to a 4D batch for processing and restored afterwards.
        angle: Rotation angle in degrees.
        translate: Translation ``(tx, ty)``.
        scale: Scale factor; must be positive.
        shear: Shear angle(s) in degrees.
        interpolation: ``InterpolationMode`` or Pillow integer constant; only "nearest" and
            "bilinear" are accepted for tensors.
        fill: Value(s) used for areas outside the transformed image, or ``None``.
        center: Optional pivot in pixel coordinates; defaults to the image center.

    Returns:
        The transformed tensor with the input's shape. An empty input is returned as is.
    """
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim

    # The grid helpers work on 4D batches: fold extra leading dims or add a batch dim.
    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
        needs_unsquash = True
    elif ndim == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True
    else:
        needs_unsquash = False

    height, width = shape[-2:]

    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [float(t) for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    # The sampling grid needs a floating dtype; non-float images use a float32 grid.
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    return output
626
627


628
@torch.jit.unused
def affine_image_pil(
    image: PIL.Image.Image,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    """Apply an affine transform to a PIL image.

    Args:
        image: Image to transform.
        angle: Rotation angle in degrees.
        translate: Translation ``(tx, ty)``.
        scale: Scale factor; must be positive.
        shear: Shear angle(s) in degrees.
        interpolation: ``InterpolationMode`` or the corresponding Pillow integer constant.
        fill: Fill value forwarded to ``_FP.affine``.
        center: Optional pivot in pixel coordinates; defaults to ``(width * 0.5, height * 0.5)``.
    """
    interpolation = _check_interpolation(interpolation)
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
    # it is visually better to estimate the center without 0.5 offset
    # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
    if center is None:
        height, width = get_spatial_size_image_pil(image)
        center = [width * 0.5, height * 0.5]
    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

    return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)
651
652


653
def _affine_bounding_box_with_expand(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    expand: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Apply an affine transform to bounding boxes, optionally for an expanded canvas.

    Each box is converted to XYXY, its four corners are transformed with the forward affine
    matrix, and the axis-aligned box enclosing the transformed corners is taken. The result
    is clamped to the (possibly expanded) spatial size and converted back to ``format``.

    Args:
        bounding_box: Box tensor; non-floating inputs are processed in float and cast back.
        format: Coordinate layout of ``bounding_box``.
        spatial_size: ``(height, width)`` of the image.
        angle: Rotation angle in degrees.
        translate: Translation ``(tx, ty)``.
        scale: Scale factor; must be positive.
        shear: Shear angle(s) in degrees.
        center: Optional pivot in pixel coordinates; defaults to ``(width * 0.5, height * 0.5)``.
        expand: If True, boxes are shifted into the frame of the transformed image and the
            returned spatial size is that of the expanded image.

    Returns:
        ``(boxes, spatial_size)`` with boxes in the input's shape, dtype and format. An empty
        input is returned unchanged together with ``spatial_size``.
    """
    if bounding_box.numel() == 0:
        return bounding_box, spatial_size

    original_shape = bounding_box.shape
    original_dtype = bounding_box.dtype
    # Both branches yield a tensor distinct from the caller's one, so the in-place
    # operations below (including inplace=True conversions) are safe.
    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
    dtype = bounding_box.dtype
    device = bounding_box.device
    bounding_box = (
        convert_format_bounding_box(
            bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
        )
    ).reshape(-1, 4)

    angle, translate, shear, center = _affine_parse_args(
        angle, translate, scale, shear, InterpolationMode.NEAREST, center
    )

    if center is None:
        height, width = spatial_size
        center = [width * 0.5, height * 0.5]

    # Boxes move with the image content, hence the forward (non-inverted) matrix.
    affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
    transposed_affine_matrix = (
        torch.tensor(
            affine_vector,
            dtype=dtype,
            device=device,
        )
        .reshape(2, 3)
        .T
    )
    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-right, bottom-left corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
    # 2) Now let's transform the points using affine matrix
    transformed_points = torch.matmul(points, transposed_affine_matrix)
    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

    if expand:
        # Compute minimum point for transformed image frame:
        # Points are Top-Left, Bottom-Left, Bottom-Right, Top-Right points.
        height, width = spatial_size
        points = torch.tensor(
            [
                [0.0, 0.0, 1.0],
                [0.0, float(height), 1.0],
                [float(width), float(height), 1.0],
                [float(width), 0.0, 1.0],
            ],
            dtype=dtype,
            device=device,
        )
        new_points = torch.matmul(points, transposed_affine_matrix)
        tr = torch.amin(new_points, dim=0, keepdim=True)
        # Translate bounding boxes
        out_bboxes.sub_(tr.repeat((1, 2)))
        # Estimate meta-data for image with inverted=True and with center=[0,0]
        affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
        new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
        spatial_size = (new_height, new_width)

    out_bboxes = clamp_bounding_box(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
    out_bboxes = convert_format_bounding_box(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)

    out_bboxes = out_bboxes.to(original_dtype)
    return out_bboxes, spatial_size
740
741
742
743


def affine_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transform to bounding boxes while keeping the canvas size fixed.

    Thin wrapper around ``_affine_bounding_box_with_expand`` called with ``expand=False``. That helper returns a
    ``(boxes, spatial_size)`` pair; only the transformed boxes are returned here.
    """
    boxes_and_size = _affine_bounding_box_with_expand(
        bounding_box,
        format=format,
        spatial_size=spatial_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
        expand=False,
    )
    # The second element (the spatial size) is not needed by callers of this kernel.
    return boxes_and_size[0]
764
765


766
767
def affine_mask(
    mask: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transform to a mask using nearest-neighbour interpolation.

    Masks with fewer than three dimensions get a temporary leading dimension before being handed to
    ``affine_image_tensor``; that dimension is removed again from the result.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    transformed = affine_image_tensor(
        mask,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    return transformed.squeeze(0) if needs_squeeze else transformed

797

798
799
800
801
802
803
def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transform to a video by forwarding all arguments to ``affine_image_tensor``."""
    return affine_image_tensor(
        video,
        angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )


820
def affine(
    inpt: datapoints._InputTypeJIT,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
    """Dispatch an affine transform based on the type of ``inpt``.

    Plain tensors (and every input while scripting) are handled by ``affine_image_tensor``, datapoints by their
    own ``affine`` method, and PIL images by ``affine_image_pil``.

    Raises:
        TypeError: If ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(affine)

    # TODO: consider deprecating integers from angle and shear in the future
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return affine_image_tensor(
            inpt,
            angle=angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.affine(
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
    elif isinstance(inpt, PIL.Image.Image):
        return affine_image_pil(
            inpt,
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
865
866


867
def rotate_image_tensor(
    image: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate an image tensor by ``angle`` through an inverse affine sampling grid.

    The last three dimensions of ``image`` are read as ``(channels, height, width)``; any leading dimensions are
    flattened for the transform and restored on the result. With ``expand=True`` the output size comes from
    ``_compute_affine_output_size``. A warning is emitted if ``center`` is passed together with ``expand=True``.
    """
    interpolation = _check_interpolation(interpolation)

    input_shape = image.shape
    num_channels, height, width = input_shape[-3:]

    center_f = [0.0, 0.0]
    if center is not None:
        if expand:
            # TODO: Do we actually want to warn, or just document this?
            warnings.warn("The provided center argument has no effect on the result if expand is True")
        # `center` is given in pixel coordinates; shift it so that (0, 0) corresponds to the image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    # The rotation direction is currently not coherent between the affine and rotate implementations,
    # hence the negated angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

    if image.numel() == 0:
        # Nothing to sample from: only the output size has to be determined.
        output = image
        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
    else:
        image = image.reshape(-1, num_channels, height, width)

        _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

        ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
        # Non-floating-point images use a float32 transformation matrix.
        grid_dtype = image.dtype if torch.is_floating_point(image) else torch.float32
        theta = torch.tensor(matrix, dtype=grid_dtype, device=image.device).reshape(1, 2, 3)
        grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
        output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

        new_height, new_width = output.shape[-2:]

    return output.reshape(input_shape[:-3] + (num_channels, new_height, new_width))
909
910


911
@torch.jit.unused
def rotate_image_pil(
    image: PIL.Image.Image,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> PIL.Image.Image:
    """Rotate a PIL image by delegating to ``_FP.rotate``.

    The interpolation argument is validated and then mapped to its PIL counterpart through ``pil_modes_mapping``.
    A warning is emitted if ``center`` is passed together with ``expand=True``.
    """
    interpolation = _check_interpolation(interpolation)

    if expand and center is not None:
        warnings.warn("The provided center argument has no effect on the result if expand is True")

    pil_interpolation = pil_modes_mapping[interpolation]
    return _FP.rotate(image, angle, interpolation=pil_interpolation, expand=expand, fill=fill, center=center)


930
931
def rotate_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Rotate bounding boxes and return them together with the resulting spatial size.

    Implemented as a pure rotation (no translation, unit scale, no shear) through
    ``_affine_bounding_box_with_expand``. The angle is negated before being passed on. A warning is emitted if
    ``center`` is passed together with ``expand=True``.
    """
    if expand and center is not None:
        warnings.warn("The provided center argument has no effect on the result if expand is True")

    no_translation = [0.0, 0.0]
    no_shear = [0.0, 0.0]
    return _affine_bounding_box_with_expand(
        bounding_box,
        format=format,
        spatial_size=spatial_size,
        angle=-angle,
        translate=no_translation,
        scale=1.0,
        shear=no_shear,
        center=center,
        expand=expand,
    )
952
953


954
955
def rotate_mask(
    mask: torch.Tensor,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate a mask using nearest-neighbour interpolation.

    Masks with fewer than three dimensions get a temporary leading dimension before being handed to
    ``rotate_image_tensor``; that dimension is removed again from the result.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    rotated = rotate_image_tensor(
        mask,
        angle=angle,
        expand=expand,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    return rotated.squeeze(0) if needs_squeeze else rotated

981

982
983
984
def rotate_video(
    video: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate a video by forwarding all arguments to ``rotate_image_tensor``."""
    return rotate_image_tensor(
        video,
        angle=angle,
        interpolation=interpolation,
        expand=expand,
        center=center,
        fill=fill,
    )


993
def rotate(
    inpt: datapoints._InputTypeJIT,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
    """Dispatch a rotation based on the type of ``inpt``.

    Plain tensors (and every input while scripting) are handled by ``rotate_image_tensor``, datapoints by their
    own ``rotate`` method, and PIL images by ``rotate_image_pil``.

    Raises:
        TypeError: If ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(rotate)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return rotate_image_tensor(
            inpt, angle=angle, interpolation=interpolation, expand=expand, center=center, fill=fill
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
    elif isinstance(inpt, PIL.Image.Image):
        return rotate_image_pil(
            inpt, angle=angle, interpolation=interpolation, expand=expand, center=center, fill=fill
        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1015
1016


1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, (tuple, list)):
        if len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        elif len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        else:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )
    else:
        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")

    return [pad_left, pad_right, pad_top, pad_bottom]
1039

1040

1041
def pad_image_tensor(
    image: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad an image tensor.

    ``fill`` defaults to ``0`` when ``None``. Scalar fills and single-element fill lists are handled by
    ``_pad_with_scalar_fill``; longer fill lists are handled by ``_pad_with_vector_fill``.

    Raises:
        ValueError: If ``padding_mode`` is not one of ``"constant"``, ``"edge"``, ``"reflect"`` or
            ``"symmetric"``.
    """
    # Be aware that `padding` has the order `[left, top, right, bottom]`, while `torch_padding` has the order
    # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use
    # `torch_pad` internally.
    torch_padding = _parse_pad_padding(padding)

    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
        raise ValueError(
            f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
            f"but got `'{padding_mode}'`."
        )

    if fill is None:
        fill = 0

    if isinstance(fill, (int, float)):
        padded = _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
    elif len(fill) == 1:
        # A one-element list is treated like a scalar fill.
        padded = _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
    else:
        padded = _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
    return padded
1067
1068
1069


def _pad_with_scalar_fill(
1070
    image: torch.Tensor,
1071
1072
1073
    torch_padding: List[int],
    fill: Union[int, float],
    padding_mode: str,
1074
) -> torch.Tensor:
1075
1076
    shape = image.shape
    num_channels, height, width = shape[-3:]
1077

1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
    batch_size = 1
    for s in shape[:-3]:
        batch_size *= s

    image = image.reshape(batch_size, num_channels, height, width)

    if padding_mode == "edge":
        # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
        # name.
        padding_mode = "replicate"

    if padding_mode == "constant":
        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
    elif padding_mode in ("reflect", "replicate"):
        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
        # TODO: See https://github.com/pytorch/pytorch/issues/40763
        dtype = image.dtype
        if not image.is_floating_point():
            needs_cast = True
            image = image.to(torch.float32)
        else:
            needs_cast = False
1101

1102
1103
1104
1105
1106
        image = torch_pad(image, torch_padding, mode=padding_mode)

        if needs_cast:
            image = image.to(dtype)
    else:  # padding_mode == "symmetric"
1107
        image = _pad_symmetric(image, torch_padding)
1108
1109

    new_height, new_width = image.shape[-2:]
1110

1111
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
1112
1113


1114
# TODO: This should be removed once torch_pad supports non-scalar padding values
def _pad_with_vector_fill(
    image: torch.Tensor,
    torch_padding: List[int],
    fill: List[float],
    padding_mode: str,
) -> torch.Tensor:
    """Pad ``image`` with a per-channel fill value.

    The image is first zero-padded through ``_pad_with_scalar_fill``; afterwards each padded border strip is
    overwritten with ``fill``, reshaped to ``(-1, 1, 1)`` so that it broadcasts over height and width.

    Raises:
        ValueError: If ``padding_mode`` is not ``"constant"``.
    """
    if padding_mode != "constant":
        raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

    left, right, top, bottom = torch_padding
    output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    fill_values = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

    # The `> 0` guards matter: e.g. `-bottom:` with `bottom == 0` would select the whole axis.
    if top > 0:
        output[..., :top, :] = fill_values
    if left > 0:
        output[..., :, :left] = fill_values
    if bottom > 0:
        output[..., -bottom:, :] = fill_values
    if right > 0:
        output[..., :, -right:] = fill_values
    return output


1139
1140
1141
# PIL kernel for padding: a direct alias of `_FP.pad`. The `pad` dispatcher calls it for `PIL.Image.Image` inputs.
pad_image_pil = _FP.pad


1142
1143
def pad_mask(
    mask: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad a mask with a scalar fill value (``0`` when ``fill`` is ``None``).

    Masks with fewer than three dimensions get a temporary leading dimension before being handed to
    ``pad_image_tensor``; that dimension is removed again from the result.

    Raises:
        ValueError: If ``fill`` is a tuple or list.
    """
    if fill is None:
        fill = 0

    if isinstance(fill, (tuple, list)):
        raise ValueError("Non-scalar fill value is not supported")

    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    padded = pad_image_tensor(mask, padding=padding, fill=fill, padding_mode=padding_mode)

    return padded.squeeze(0) if needs_squeeze else padded
1166
1167


1168
def pad_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    padding: List[int],
    padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Shift bounding boxes to account for padding and return them with the padded spatial size.

    For ``XYXY`` boxes all four coordinates are shifted by the left/top padding; for every other format only the
    first two entries are shifted. The returned boxes are clamped to the new ``(height, width)``.

    Raises:
        ValueError: If ``padding_mode`` is not ``"constant"``.
    """
    if padding_mode != "constant":
        # TODO: add support of other padding modes
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

    left, right, top, bottom = _parse_pad_padding(padding)

    if format == datapoints.BoundingBoxFormat.XYXY:
        offsets = [left, top, left, top]
    else:
        offsets = [left, top, 0, 0]
    shifted = bounding_box + torch.tensor(offsets, dtype=bounding_box.dtype, device=bounding_box.device)

    old_height, old_width = spatial_size
    padded_size = (old_height + top + bottom, old_width + left + right)

    return clamp_bounding_box(shifted, format=format, spatial_size=padded_size), padded_size
1193
1194


1195
1196
def pad_video(
    video: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad a video by forwarding all arguments to ``pad_image_tensor``."""
    return pad_image_tensor(video, padding=padding, fill=fill, padding_mode=padding_mode)


1204
def pad(
    inpt: datapoints._InputTypeJIT,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> datapoints._InputTypeJIT:
    """Dispatch padding based on the type of ``inpt``.

    Plain tensors (and every input while scripting) are handled by ``pad_image_tensor``, datapoints by their own
    ``pad`` method, and PIL images by ``pad_image_pil``.

    Raises:
        TypeError: If ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(pad)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return pad_image_tensor(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
    elif isinstance(inpt, PIL.Image.Image):
        return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1225
1226


1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop the last two dimensions of ``image`` to the window ``[top, top + height) x [left, left + width)``.

    If the window reaches outside the image, the overlapping part is sliced out and the missing area is
    zero-padded through ``_pad_with_scalar_fill``.
    """
    image_height, image_width = image.shape[-2:]

    right = left + width
    bottom = top + height

    fits_inside = left >= 0 and top >= 0 and right <= image_width and bottom <= image_height
    if fits_inside:
        return image[..., top:bottom, left:right]

    # Keep whatever part of the window overlaps the image ...
    image = image[..., max(top, 0) : bottom, max(left, 0) : right]
    # ... and pad the rest. The order is [left, right, top, bottom], as expected by `_pad_with_scalar_fill`.
    pad_left = max(min(right, 0) - left, 0)
    pad_right = max(right - max(image_width, left), 0)
    pad_top = max(min(bottom, 0) - top, 0)
    pad_bottom = max(bottom - max(image_height, top), 0)
    return _pad_with_scalar_fill(image, [pad_left, pad_right, pad_top, pad_bottom], fill=0, padding_mode="constant")


1245
1246
1247
# PIL kernel for cropping: a direct alias of `_FP.crop`. The `crop` dispatcher calls it for `PIL.Image.Image` inputs.
crop_image_pil = _FP.crop


1248
1249
def crop_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Shift bounding boxes into the frame of a crop window and return them with the new spatial size.

    For ``XYXY`` boxes all four coordinates are shifted by ``(left, top)``; for every other format only the first
    two entries are shifted. The returned boxes are clamped to ``(height, width)``.
    """
    # Crop or implicit pad if left and/or top have negative values:
    if format == datapoints.BoundingBoxFormat.XYXY:
        offsets = [left, top, left, top]
    else:
        offsets = [left, top, 0, 0]

    shifted = bounding_box - torch.tensor(offsets, dtype=bounding_box.dtype, device=bounding_box.device)
    cropped_size = (height, width)

    return clamp_bounding_box(shifted, format=format, spatial_size=cropped_size), cropped_size
1267
1268


1269
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a mask with ``crop_image_tensor``.

    Masks with fewer than three dimensions get a temporary leading dimension for the call, which is removed
    again from the result.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    cropped = crop_image_tensor(mask, top, left, height, width)

    return cropped.squeeze(0) if needs_squeeze else cropped
1282
1283


1284
1285
1286
1287
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a video by forwarding all arguments to ``crop_image_tensor``."""
    return crop_image_tensor(video, top=top, left=left, height=height, width=width)


Philip Meier's avatar
Philip Meier committed
1288
def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
    """Dispatch cropping based on the type of ``inpt``.

    Plain tensors (and every input while scripting) are handled by ``crop_image_tensor``, datapoints by their own
    ``crop`` method, and PIL images by ``crop_image_pil``.

    Raises:
        TypeError: If ``inpt`` is none of the supported types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return crop_image_tensor(inpt, top=top, left=left, height=height, width=width)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.crop(top, left, height, width)
    elif isinstance(inpt, PIL.Image.Image):
        return crop_image_pil(inpt, top, left, height, width)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
1303
1304


1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394

    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
1320
    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
1321
    base_grid[..., 0].copy_(x_grid)
1322
    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
1323
1324
1325
1326
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
1327
1328
1329
    shape = (1, oh * ow, 3)
    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))
1330
1331
1332
1333
1334

    output_grid = output_grid1.div_(output_grid2).sub_(1.0)
    return output_grid.view(1, oh, ow, 2)


1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
def _perspective_coefficients(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]],
) -> List[float]:
    if coefficients is not None:
        if startpoints is not None and endpoints is not None:
            raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.")
        elif len(coefficients) != 8:
            raise ValueError("Argument coefficients should have 8 float values")
        return coefficients
    elif startpoints is not None and endpoints is not None:
        return _get_perspective_coeffs(startpoints, endpoints)
    else:
        raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")


1352
def perspective_image_tensor(
    image: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transform to an image tensor.

    The transform is defined either by ``startpoints``/``endpoints`` or by explicit ``coefficients``. Empty
    images are returned unchanged. Inputs with 3 or more than 4 dimensions are temporarily brought into a 4D
    layout and restored to their original shape afterwards; the output keeps the input's spatial size.
    """
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    original_shape = image.shape
    rank = image.ndim

    needs_unsquash = rank == 3 or rank > 4
    if rank > 4:
        # Collapse all leading dimensions into a single batch dimension.
        image = image.reshape((-1,) + original_shape[-3:])
    elif rank == 3:
        image = image.unsqueeze(0)

    _assert_grid_transform_inputs(
        image,
        matrix=None,
        interpolation=interpolation.value,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    oh, ow = original_shape[-2:]
    # Non-floating-point images use a float32 sampling grid.
    grid_dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=grid_dtype, device=image.device)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    return output.reshape(original_shape) if needs_unsquash else output
1396
1397


1398
@torch.jit.unused
def perspective_image_pil(
    image: PIL.Image.Image,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
    """Apply a perspective transform to a PIL image by delegating to ``_FP.perspective``.

    The transform is defined either by ``startpoints``/``endpoints`` or by explicit ``coefficients``. The
    interpolation argument is validated and then mapped to its PIL counterpart through ``pil_modes_mapping``.
    """
    # NOTE(review): the default interpolation here is BICUBIC while `perspective_image_tensor` defaults to
    # BILINEAR -- confirm whether this difference is intended.
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)
    pil_interpolation = pil_modes_mapping[interpolation]
    return _FP.perspective(image, perspective_coeffs, interpolation=pil_interpolation, fill=fill)
1410
1411


1412
1413
def perspective_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transform to bounding boxes.

    The transform is defined either by ``startpoints``/``endpoints`` or by explicit ``coefficients``. Each box is
    converted to ``XYXY``, its four corners are mapped through the inverted perspective transform, and the
    axis-aligned box enclosing the mapped corners is clamped to ``spatial_size`` and converted back to
    ``format``. Empty inputs are returned unchanged. The result has the shape of the input.

    Raises:
        RuntimeError: If the perspective coefficients cannot be inverted
            (``coeffs[0] * coeffs[4] - coeffs[1] * coeffs[3] == 0``).
    """
    if bounding_box.numel() == 0:
        return bounding_box

    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)

    original_shape = bounding_box.shape
    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
    bounding_box = (
        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    # Non-floating-point boxes are transformed in float32.
    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
    device = bounding_box.device

    # perspective_coeffs are computed as endpoint -> start point
    # We have to invert perspective_coeffs for bboxes:
    # (x, y) - end point and (x_out, y_out) - start point
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # and we would like to get:
    # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # and compute inv_coeffs in terms of coeffs

    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
    if denom == 0:
        raise RuntimeError(
            f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
            f"Denominator is zero, denom={denom}"
        )

    inv_coeffs = [
        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
        (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
        (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
        (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
    ]

    theta1 = torch.tensor(
        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
        dtype=dtype,
        device=device,
    )

    # Both output coordinates share the same denominator, hence the duplicated row.
    theta2 = torch.tensor(
        [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
    )

    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-right, bottom-left corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    # The homogeneous column is created in `dtype` so that the concatenated points match `theta1` / `theta2`.
    # With the default float32 column, float16 / bfloat16 boxes would be promoted to float32 by `torch.cat`,
    # and the matmuls below would then see mismatched dtypes.
    points = torch.cat([points, torch.ones(points.shape[0], 1, dtype=dtype, device=device)], dim=-1)
    # 2) Now let's transform the points using perspective matrices
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)

    numer_points = torch.matmul(points, theta1.T)
    denom_points = torch.matmul(points, theta2.T)
    transformed_points = numer_points.div_(denom_points)

    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)

    out_bboxes = clamp_bounding_box(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=spatial_size,
    )

    # out_bboxes should be of shape [N boxes, 4]

    return convert_format_bounding_box(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
1504
1505


1506
1507
def perspective_mask(
    mask: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transform to a mask with nearest-neighbour interpolation.

    A mask with fewer than three dimensions gets a temporary leading dimension for the call to
    ``perspective_image_tensor``; that dimension is removed again before the result is returned.
    """
    # Decide once whether a temporary leading dimension is needed.
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = perspective_image_tensor(
        mask, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients
    )

    if needs_squeeze:
        return output.squeeze(0)
    return output

1528

1529
1530
def perspective_video(
    video: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transform to a video by delegating to ``perspective_image_tensor``.

    All arguments are forwarded unchanged.
    """
    output = perspective_image_tensor(
        video, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
    )
    return output
1540
1541


1542
def perspective(
    inpt: datapoints._InputTypeJIT,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
    """Dispatch a perspective transform to the kernel matching the type of ``inpt``.

    Plain tensors (and every input while scripting) go to ``perspective_image_tensor``, datapoints are
    handled by their own ``perspective`` method, and PIL images go to ``perspective_image_pil``.

    Raises:
        TypeError: If ``inpt`` is none of the supported input types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(perspective)

    # While scripting, the first branch is always the one taken.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return perspective_image_tensor(
            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.perspective(
            startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
        )
    elif isinstance(inpt, PIL.Image.Image):
        return perspective_image_pil(
            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
        )
    else:
        message = (
            "Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
        raise TypeError(message)
1569
1570


1571
def elastic_image_tensor(
    image: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Warp ``image`` by sampling it on an identity grid shifted by ``displacement``.

    Args:
        image: Tensor whose last two dimensions are (height, width). Inputs with three dimensions or with
            more than four dimensions are temporarily brought to four dimensions.
        displacement: Tensor of shape ``(1, height, width, 2)``. It is cast to the computation dtype and
            moved to the device of ``image`` when they differ.
        interpolation: Interpolation mode; validated by ``_check_interpolation``.
        fill: Fill value forwarded to ``_apply_grid_transform``.

    Returns:
        The warped tensor. An empty ``image`` is returned as is, and inputs that were reshaped are restored
        to their original shape.

    Raises:
        ValueError: If ``displacement`` does not have the shape ``(1, height, width, 2)``.
    """
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim
    device = image.device

    # The grid and the displacement use the image dtype for floating point images and float32 otherwise.
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32

    # Patch: elastic transform should support (cpu,f16) input.
    # The computation is done in float32 and the result is cast back to float16 at the end.
    is_cpu_half = device.type == "cpu" and dtype == torch.float16
    if is_cpu_half:
        dtype = torch.float32
        image = image.to(torch.float32)

    # We are aware that if the input image dtype is uint8 and the displacement is float64, the displacement
    # will be cast to float32 and all computations will be done in float32.
    # We can fix this later if needed.
    expected_shape = (1,) + shape[-2:] + (2,)
    if expected_shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    # Bring the image to 4 dimensions; remember whether the original shape has to be restored.
    needs_unsquash = ndim > 4 or ndim == 3
    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
    elif ndim == 3:
        image = image.unsqueeze(0)

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    image_height, image_width = shape[-2:]
    sampling_grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype)
    sampling_grid.add_(displacement)
    output = _apply_grid_transform(image, sampling_grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    if is_cpu_half:
        output = output.to(torch.float16)

    return output
1625
1626


1627
@torch.jit.unused
def elastic_image_pil(
    image: PIL.Image.Image,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> PIL.Image.Image:
    """Elastic transform for PIL images.

    The image is converted to a tensor, warped with ``elastic_image_tensor`` and converted back to a PIL
    image using the mode of the input image.
    """
    warped = elastic_image_tensor(pil_to_tensor(image), displacement, interpolation=interpolation, fill=fill)
    return to_pil_image(warped, mode=image.mode)
1637
1638


1639
def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
1640
    sy, sx = size
1641
1642
    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
1643
1644
    base_grid[..., 0].copy_(x_grid)

1645
    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
1646
1647
1648
1649
1650
    base_grid[..., 1].copy_(y_grid)

    return base_grid


1651
1652
def elastic_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    displacement: torch.Tensor,
) -> torch.Tensor:
    """Move bounding boxes according to an elastic displacement field.

    The four corners of every box are looked up in an approximate inverse sampling grid
    (``identity grid - displacement``) and mapped back to pixel coordinates. The axis-aligned box that
    encloses the moved corners is clamped to ``spatial_size`` and converted back to ``format``.

    Args:
        bounding_box: Boxes in ``format``. They are reshaped to ``(-1, 4)`` internally.
        format: Format of ``bounding_box``; the result uses the same format.
        spatial_size: ``(height, width)`` of the image the boxes belong to.
        displacement: Displacement field. It is cast to the computation dtype and moved to the device of
            ``bounding_box`` when they differ.
            NOTE(review): presumably shaped ``(1, height, width, 2)`` as required by
            ``elastic_image_tensor``; the shape is not validated here.

    Returns:
        Tensor with the shape of the input ``bounding_box``. An empty input is returned as is.
    """
    if bounding_box.numel() == 0:
        return bounding_box

    # TODO: add in docstring about approximation we are doing for grid inversion
    device = bounding_box.device
    # Integer boxes are processed in float32.
    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    original_shape = bounding_box.shape
    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
    bounding_box = (
        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
    # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
    # This is not an exact inverse of the grid
    inv_grid = id_grid.sub_(displacement)

    # Get points from bboxes: (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax) for every box.
    # Advanced indexing returns a copy, so the in-place ceil below does not touch the input.
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    if points.is_floating_point():
        points = points.ceil_()
    index_xy = points.to(dtype=torch.long)

    # The grid only has entries for x in [0, width - 1] and y in [0, height - 1]. A box that touches the
    # right/bottom image border has x == width or y == height, which would index out of bounds, and a
    # negative coordinate would silently wrap around. Clamp the indices to the valid range instead.
    index_x = index_xy[:, 0].clamp(0, spatial_size[1] - 1)
    index_y = index_xy[:, 1].clamp(0, spatial_size[0] - 1)

    # Transform points: map the normalized grid values back to pixel coordinates.
    t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

    # Enclose the four moved corners of each box in an axis-aligned box.
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = clamp_bounding_box(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=spatial_size,
    )

    return convert_format_bounding_box(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
1700
1701


1702
1703
1704
def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Apply an elastic transform to a mask with nearest-neighbour interpolation.

    A mask with fewer than three dimensions gets a temporary leading dimension for the call to
    ``elastic_image_tensor``; that dimension is removed again before the result is returned.
    """
    # Decide once whether a temporary leading dimension is needed.
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)

    if needs_squeeze:
        return output.squeeze(0)
    return output
1719
1720


1721
1722
1723
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    """Apply an elastic transform to a video by delegating to ``elastic_image_tensor``.

    All arguments are forwarded unchanged.
    """
    output = elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
    return output
1728
1729


1730
def elastic(
    inpt: datapoints._InputTypeJIT,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
    """Dispatch an elastic transform to the kernel matching the type of ``inpt``.

    Plain tensors (and every input while scripting) go to ``elastic_image_tensor``, datapoints are handled
    by their own ``elastic`` method, and PIL images go to ``elastic_image_pil``.

    Raises:
        TypeError: If ``displacement`` is not a tensor, or if ``inpt`` is none of the supported input types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(elastic)

    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    # While scripting, the first branch is always the one taken.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
    elif isinstance(inpt, PIL.Image.Image):
        return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
    else:
        message = (
            "Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
        raise TypeError(message)
1753
1754
1755
1756
1757


# Alias: ``elastic_transform`` is bound to the same function object as ``elastic``.
# NOTE(review): presumably kept for naming parity with the v1 functional API - confirm.
elastic_transform = elastic


1758
1759
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    if isinstance(output_size, numbers.Number):
1760
1761
        s = int(output_size)
        return [s, s]
1762
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
1763
        return [output_size[0], output_size[0]]
1764
1765
    else:
        return list(output_size)
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784


def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
    return [
        (crop_width - image_width) // 2 if crop_width > image_width else 0,
        (crop_height - image_height) // 2 if crop_height > image_height else 0,
        (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
        (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
    ]


def _center_crop_compute_crop_anchor(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop_top, crop_left


1785
def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Crop the center of ``image`` to ``output_size``, zero-padding first if the image is too small.

    Args:
        image: Tensor whose last two dimensions are (height, width).
        output_size: Requested size; parsed by ``_center_crop_parse_output_size``.

    Returns:
        The cropped tensor. An empty ``image`` is reshaped so that its last two dimensions match the
        requested size.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)

    shape = image.shape
    if image.numel() == 0:
        return image.reshape(shape[:-2] + (crop_height, crop_width))

    image_height, image_width = shape[-2:]
    needs_padding = crop_height > image_height or crop_width > image_width
    if needs_padding:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)

        # After padding the image may already have exactly the requested size.
        image_height, image_width = image.shape[-2:]
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    crop_bottom = crop_top + crop_height
    crop_right = crop_left + crop_width
    return image[..., crop_top:crop_bottom, crop_left:crop_right]
1802
1803


1804
@torch.jit.unused
def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
    """Crop the center of a PIL image to ``output_size``, padding with ``fill=0`` first if it is too small.

    Args:
        image: PIL image to crop.
        output_size: Requested size; parsed by ``_center_crop_parse_output_size``.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    image_height, image_width = get_spatial_size_image_pil(image)

    needs_padding = crop_height > image_height or crop_width > image_width
    if needs_padding:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = pad_image_pil(image, padding_ltrb, fill=0)

        # After padding the image may already have exactly the requested size.
        image_height, image_width = get_spatial_size_image_pil(image)
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)
1819
1820


1821
1822
def center_crop_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Adjust bounding boxes for a center crop of the image to ``output_size``.

    The crop anchor is computed from ``spatial_size`` and the parsed ``output_size``; the actual work is
    delegated to ``crop_bounding_box``, whose result is returned unchanged.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size)
    cropped = crop_bounding_box(
        bounding_box, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width
    )
    return cropped
1830
1831


1832
1833
1834
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Center crop a mask to ``output_size``.

    A mask with fewer than three dimensions gets a temporary leading dimension for the call to
    ``center_crop_image_tensor``; that dimension is removed again before the result is returned.
    """
    # Decide once whether a temporary leading dimension is needed.
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = center_crop_image_tensor(image=mask, output_size=output_size)

    if needs_squeeze:
        return output.squeeze(0)
    return output
1845
1846


1847
1848
1849
1850
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Center crop a video by delegating to ``center_crop_image_tensor``."""
    output = center_crop_image_tensor(video, output_size)
    return output


Philip Meier's avatar
Philip Meier committed
1851
def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
    """Dispatch a center crop to the kernel matching the type of ``inpt``.

    Plain tensors (and every input while scripting) go to ``center_crop_image_tensor``, datapoints are
    handled by their own ``center_crop`` method, and PIL images go to ``center_crop_image_pil``.

    Raises:
        TypeError: If ``inpt`` is none of the supported input types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(center_crop)

    # While scripting, the first branch is always the one taken.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return center_crop_image_tensor(inpt, output_size)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.center_crop(output_size)
    elif isinstance(inpt, PIL.Image.Image):
        return center_crop_image_pil(inpt, output_size)
    else:
        message = (
            "Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
        raise TypeError(message)
1866
1867


1868
def resized_crop_image_tensor(
    image: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Crop ``image`` with ``crop_image_tensor`` and resize the crop with ``resize_image_tensor``.

    ``top``, ``left``, ``height`` and ``width`` describe the crop; ``size``, ``interpolation`` and
    ``antialias`` are forwarded to the resize step.
    """
    cropped = crop_image_tensor(image, top, left, height, width)
    return resize_image_tensor(cropped, size, interpolation=interpolation, antialias=antialias)
1880
1881


1882
@torch.jit.unused
def resized_crop_image_pil(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
    """Crop a PIL image with ``crop_image_pil`` and resize the crop with ``resize_image_pil``.

    ``top``, ``left``, ``height`` and ``width`` describe the crop; ``size`` and ``interpolation`` are
    forwarded to the resize step.
    """
    cropped = crop_image_pil(image, top, left, height, width)
    return resize_image_pil(cropped, size, interpolation=interpolation)
1894
1895


1896
1897
def resized_crop_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Adjust bounding boxes for a crop followed by a resize.

    The boxes are first cropped with ``crop_bounding_box`` (its second return value is discarded) and
    then resized with ``resize_bounding_box``, using the crop size ``(height, width)`` as the spatial
    size before resizing.
    """
    cropped = crop_bounding_box(bounding_box, format, top, left, height, width)[0]
    return resize_bounding_box(cropped, spatial_size=(height, width), size=size)
1907
1908


1909
def resized_crop_mask(
    mask: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> torch.Tensor:
    """Crop a mask with ``crop_mask`` and resize the crop to ``size`` with ``resize_mask``."""
    cropped = crop_mask(mask, top, left, height, width)
    return resize_mask(cropped, size)
1919
1920


1921
1922
1923
1924
1925
1926
1927
def resized_crop_video(
    video: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Crop and resize a video by delegating to ``resized_crop_image_tensor``.

    All arguments are forwarded unchanged.
    """
    output = resized_crop_image_tensor(
        video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
    )
    return output


1936
def resized_crop(
    inpt: datapoints._InputTypeJIT,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
    """Dispatch a crop-then-resize to the kernel matching the type of ``inpt``.

    Plain tensors (and every input while scripting) go to ``resized_crop_image_tensor``, datapoints are
    handled by their own ``resized_crop`` method, and PIL images go to ``resized_crop_image_pil``.
    ``antialias`` is not forwarded to the PIL kernel.

    Raises:
        TypeError: If ``inpt`` is none of the supported input types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(resized_crop)

    # While scripting, the first branch is always the one taken.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return resized_crop_image_tensor(
            inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
    elif isinstance(inpt, PIL.Image.Image):
        return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
    else:
        message = (
            "Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
        raise TypeError(message)
1962
1963


1964
1965
def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
1966
1967
        s = int(size)
        size = [s, s]
1968
    elif isinstance(size, (tuple, list)) and len(size) == 1:
1969
1970
        s = size[0]
        size = [s, s]
1971
1972
1973
1974
1975
1976
1977
1978

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size


def five_crop_image_tensor(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Crop the four corners and the center of ``image``.

    Args:
        image: Tensor whose last two dimensions are (height, width).
        size: Crop size; parsed by ``_parse_five_crop_size``.

    Returns:
        The crops in the order top-left, top-right, bottom-left, bottom-right, center.

    Raises:
        ValueError: If the requested crop is larger than the image along either dimension.
    """
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = image.shape[-2:]

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    # Offsets of the crops that touch the bottom and the right border.
    last_top = image_height - crop_height
    last_left = image_width - crop_width

    top_left = crop_image_tensor(image, 0, 0, crop_height, crop_width)
    top_right = crop_image_tensor(image, 0, last_left, crop_height, crop_width)
    bottom_left = crop_image_tensor(image, last_top, 0, crop_height, crop_width)
    bottom_right = crop_image_tensor(image, last_top, last_left, crop_height, crop_width)
    center = center_crop_image_tensor(image, [crop_height, crop_width])

    return top_left, top_right, bottom_left, bottom_right, center


1996
@torch.jit.unused
def five_crop_image_pil(
    image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    """Crop the four corners and the center of a PIL image.

    Args:
        image: PIL image to crop.
        size: Crop size; parsed by ``_parse_five_crop_size``.

    Returns:
        The crops in the order top-left, top-right, bottom-left, bottom-right, center.

    Raises:
        ValueError: If the requested crop is larger than the image along either dimension.
    """
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = get_spatial_size_image_pil(image)

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    # Offsets of the crops that touch the bottom and the right border.
    last_top = image_height - crop_height
    last_left = image_width - crop_width

    top_left = crop_image_pil(image, 0, 0, crop_height, crop_width)
    top_right = crop_image_pil(image, 0, last_left, crop_height, crop_width)
    bottom_left = crop_image_pil(image, last_top, 0, crop_height, crop_width)
    bottom_right = crop_image_pil(image, last_top, last_left, crop_height, crop_width)
    center = center_crop_image_pil(image, [crop_height, crop_width])

    return top_left, top_right, bottom_left, bottom_right, center


2015
2016
2017
2018
2019
2020
def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Five-crop a video by delegating to ``five_crop_image_tensor``."""
    crops = five_crop_image_tensor(video, size)
    return crops


Philip Meier's avatar
Philip Meier committed
2021
# Shared annotation for the inputs and outputs of the ``five_crop`` / ``ten_crop`` dispatchers:
# either the image or the video input type.
ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]
2022
2023


2024
def five_crop(
    inpt: ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
    """Dispatch a five-crop to the kernel matching the type of ``inpt``.

    Plain tensors (and every input while scripting) go to ``five_crop_image_tensor``. ``Image`` and
    ``Video`` datapoints are unwrapped to plain tensors, cropped, and every crop is wrapped again like
    the input. PIL images go to ``five_crop_image_pil``.

    Raises:
        TypeError: If ``inpt`` is none of the supported input types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(five_crop)

    # While scripting, the first branch is always the one taken.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return five_crop_image_tensor(inpt, size)
    elif isinstance(inpt, datapoints.Image):
        image_crops = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
        return tuple(datapoints.Image.wrap_like(inpt, item) for item in image_crops)  # type: ignore[return-value]
    elif isinstance(inpt, datapoints.Video):
        video_crops = five_crop_video(inpt.as_subclass(torch.Tensor), size)
        return tuple(datapoints.Video.wrap_like(inpt, item) for item in video_crops)  # type: ignore[return-value]
    elif isinstance(inpt, PIL.Image.Image):
        return five_crop_image_pil(inpt, size)
    else:
        message = (
            "Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
        raise TypeError(message)
2045
2046


Philip Meier's avatar
Philip Meier committed
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
def ten_crop_image_tensor(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Five-crop ``image`` and a flipped copy of it.

    Args:
        image: Tensor to crop.
        size: Crop size forwarded to ``five_crop_image_tensor``.
        vertical_flip: Flip vertically if ``True``, horizontally otherwise.

    Returns:
        The five crops of the original image followed by the five crops of the flipped image.
    """
    original_crops = five_crop_image_tensor(image, size)

    if vertical_flip:
        flipped_image = vertical_flip_image_tensor(image)
    else:
        flipped_image = horizontal_flip_image_tensor(image)

    flipped_crops = five_crop_image_tensor(flipped_image, size)

    return original_crops + flipped_crops
2071
2072


2073
@torch.jit.unused
def ten_crop_image_pil(
    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
]:
    """Five-crop a PIL image and a flipped copy of it.

    Args:
        image: PIL image to crop.
        size: Crop size forwarded to ``five_crop_image_pil``.
        vertical_flip: Flip vertically if ``True``, horizontally otherwise.

    Returns:
        The five crops of the original image followed by the five crops of the flipped image.
    """
    original_crops = five_crop_image_pil(image, size)

    if vertical_flip:
        flipped_image = vertical_flip_image_pil(image)
    else:
        flipped_image = horizontal_flip_image_pil(image)

    flipped_crops = five_crop_image_pil(flipped_image, size)

    return original_crops + flipped_crops


def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Ten-crop a video by delegating to ``ten_crop_image_tensor``."""
    crops = ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
    return crops


def ten_crop(
    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False
) -> Tuple[
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
]:
    """Dispatch a ten-crop to the kernel matching the type of ``inpt``.

    Plain tensors (and every input while scripting) go to ``ten_crop_image_tensor``. ``Image`` and
    ``Video`` datapoints are unwrapped to plain tensors, cropped, and every crop is wrapped again like
    the input. PIL images go to ``ten_crop_image_pil``.

    Raises:
        TypeError: If ``inpt`` is none of the supported input types.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(ten_crop)

    # While scripting, the first branch is always the one taken.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
    elif isinstance(inpt, datapoints.Image):
        image_crops = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
        return tuple(datapoints.Image.wrap_like(inpt, item) for item in image_crops)  # type: ignore[return-value]
    elif isinstance(inpt, datapoints.Video):
        video_crops = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
        return tuple(datapoints.Video.wrap_like(inpt, item) for item in video_crops)  # type: ignore[return-value]
    elif isinstance(inpt, PIL.Image.Image):
        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
    else:
        message = (
            "Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
        raise TypeError(message)