import math
import numbers
import warnings
from typing import List, Optional, Sequence, Tuple, Union

import PIL.Image
import torch
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad

from torchvision import datapoints
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric
from torchvision.transforms.functional import (
    _check_antialias,
    _compute_resized_output_size as __compute_resized_output_size,
    _get_perspective_coeffs,
    _interpolation_modes_from_int,
    InterpolationMode,
    pil_modes_mapping,
    pil_to_tensor,
    to_pil_image,
)

from torchvision.utils import _log_api_usage_once

from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil
from ._utils import is_simple_tensor

def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise ValueError(
            f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
            f"but got {interpolation}."
        )
    return interpolation


def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    return image.flip(-1)


horizontal_flip_image_pil = _FP.hflip


def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    return horizontal_flip_image_tensor(mask)


def horizontal_flip_bounding_box(
    bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_box.shape

    bounding_box = bounding_box.clone().reshape(-1, 4)

    if format == datapoints.BoundingBoxFormat.XYXY:
        bounding_box[:, [2, 0]] = bounding_box[:, [0, 2]].sub_(spatial_size[1]).neg_()
    elif format == datapoints.BoundingBoxFormat.XYWH:
        bounding_box[:, 0].add_(bounding_box[:, 2]).sub_(spatial_size[1]).neg_()
    else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
        bounding_box[:, 0].sub_(spatial_size[1]).neg_()

    return bounding_box.reshape(shape)


def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    return horizontal_flip_image_tensor(video)


def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(horizontal_flip)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return horizontal_flip_image_tensor(inpt)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.horizontal_flip()
    elif isinstance(inpt, PIL.Image.Image):
        return horizontal_flip_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
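
# Illustrative sketch (not part of the library code): for an XYXY box in an image of
# width 100, a horizontal flip mirrors the x-coordinates and swaps them so that x1 <= x2
# still holds. Assuming a hypothetical box [10, 20, 30, 40]:
#
#     boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
#     flipped = horizontal_flip_bounding_box(
#         boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(50, 100)
#     )
#     # flipped -> tensor([[70., 20., 90., 40.]])

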
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    return image.flip(-2)


vertical_flip_image_pil = _FP.vflip


def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image_tensor(mask)


def vertical_flip_bounding_box(
    bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_box.shape

    bounding_box = bounding_box.clone().reshape(-1, 4)

    if format == datapoints.BoundingBoxFormat.XYXY:
        bounding_box[:, [1, 3]] = bounding_box[:, [3, 1]].sub_(spatial_size[0]).neg_()
    elif format == datapoints.BoundingBoxFormat.XYWH:
        bounding_box[:, 1].add_(bounding_box[:, 3]).sub_(spatial_size[0]).neg_()
    else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
        bounding_box[:, 1].sub_(spatial_size[0]).neg_()

    return bounding_box.reshape(shape)


def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image_tensor(video)


def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(vertical_flip)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return vertical_flip_image_tensor(inpt)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.vertical_flip()
    elif isinstance(inpt, PIL.Image.Image):
        return vertical_flip_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
hflip = horizontal_flip
vflip = vertical_flip


def _compute_resized_output_size(
    spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
    if isinstance(size, int):
        size = [size]
    elif max_size is not None and len(size) != 1:
        raise ValueError(
            "max_size should only be passed if size specifies the length of the smaller edge, "
            "i.e. size should be an int or a sequence of length 1 in torchscript mode."
        )
    return __compute_resized_output_size(spatial_size, size=size, max_size=max_size)


def resize_image_tensor(
    image: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)
    antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation)
    assert not isinstance(antialias, str)
    antialias = False if antialias is None else antialias
    align_corners: Optional[bool] = None
    if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
        align_corners = False
    else:
        # The default of antialias should be True from 0.17, so we don't warn or
        # error if other interpolation modes are used. This is documented.
        antialias = False

    shape = image.shape
    num_channels, old_height, old_width = shape[-3:]
    new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return image
    elif image.numel() > 0:
        image = image.reshape(-1, num_channels, old_height, old_width)

        dtype = image.dtype
        need_cast = dtype not in (torch.float32, torch.float64)
        if need_cast:
            image = image.to(dtype=torch.float32)

        image = interpolate(
            image,
            size=[new_height, new_width],
            mode=interpolation.value,
            align_corners=align_corners,
            antialias=antialias,
        )

        if need_cast:
            if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
                image = image.clamp_(min=0, max=255)
            image = image.round_().to(dtype=dtype)

    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


@torch.jit.unused
def resize_image_pil(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
) -> PIL.Image.Image:
    old_height, old_width = image.height, image.width
    new_height, new_width = _compute_resized_output_size(
        (old_height, old_width),
        size=size,  # type: ignore[arg-type]
        max_size=max_size,
    )

    interpolation = _check_interpolation(interpolation)

    if (new_height, new_width) == (old_height, old_width):
        return image

    return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation])


def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = resize_image_tensor(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


def resize_bounding_box(
    bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    old_height, old_width = spatial_size
    new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return bounding_box, spatial_size

    w_ratio = new_width / old_width
    h_ratio = new_height / old_height
    ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_box.device)
    return (
        bounding_box.mul(ratios).to(bounding_box.dtype),
        (new_height, new_width),
    )
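
# Illustrative example (a sketch, not library documentation): resizing from (100, 200) to a
# smaller edge of 50 yields a new spatial size of (50, 100), so every coordinate is scaled
# by 0.5 regardless of the box format:
#
#     boxes = torch.tensor([[20.0, 40.0, 60.0, 80.0]])
#     out, new_size = resize_bounding_box(boxes, spatial_size=(100, 200), size=[50])
#     # out -> tensor([[10., 20., 30., 40.]]), new_size -> (50, 100)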


def resize_video(
    video: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


def resize(
    inpt: datapoints._InputTypeJIT,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(resize)
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
    elif isinstance(inpt, PIL.Image.Image):
        if antialias is False:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
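
# Illustrative usage (a sketch, not library documentation): a single value in `size`
# resizes the smaller edge and keeps the aspect ratio. Assuming a hypothetical uint8
# tensor of shape (3, 200, 400):
#
#     img = torch.randint(0, 256, (3, 200, 400), dtype=torch.uint8)
#     out = resize(img, [100], antialias=True)
#     # out.shape -> torch.Size([3, 100, 200])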


def _affine_parse_args(
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None:
        if not isinstance(center, (list, tuple)):
            raise TypeError("Argument center should be a sequence")
        else:
            center = [float(c) for c in center]

    return angle, translate, shear, center


def _get_inverse_affine_matrix(
    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
    # Helper method to compute inverse matrix for affine transformation

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                      , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # Cached results
    cos_sy = math.cos(sy)
    tan_sx = math.tan(sx)
    rot_minus_sy = rot - sy
    cx_plus_tx = cx + tx
    cy_plus_ty = cy + ty

    # Rotate Scale Shear (RSS) without scaling
    a = math.cos(rot_minus_sy) / cos_sy
    b = -(a * tan_sx + math.sin(rot))
    c = math.sin(rot_minus_sy) / cos_sy
    d = math.cos(rot) - c * tan_sx

    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
    else:
        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
        # Apply inverse of center translation: RSS * C^-1
        # and then apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy

    return matrix
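
# Quick sanity check (illustrative, not executed at import): with center=[0.0, 0.0],
# angle=0.0, translate=[0.0, 0.0], scale=1.0 and shear=[0.0, 0.0] the helper returns the
# identity matrix in row-major 2x3 form:
#
#     _get_inverse_affine_matrix([0.0, 0.0], 0.0, [0.0, 0.0], 1.0, [0.0, 0.0])
#     # -> [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]  (modulo signed zeros)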


def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired of PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
    # Points are shifted due to affine matrix torch convention about
    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
    half_w = 0.5 * w
    half_h = 0.5 * h
    pts = torch.tensor(
        [
            [-half_w, -half_h, 1.0],
            [-half_w, half_h, 1.0],
            [half_w, half_h, 1.0],
            [half_w, -half_h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, max_vals = new_pts.aminmax(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    halfs = torch.tensor((half_w, half_h))
    min_vals.add_(halfs)
    max_vals.add_(halfs)

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    inv_tol = 1.0 / tol
    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
    size = cmax.sub_(cmin)
    return int(size[0]), int(size[1])  # w, h
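
# Illustrative sanity check (not executed at import): for the identity affine matrix the
# canvas size is unchanged, e.g. _compute_affine_output_size([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 32, 24)
# returns (32, 24); for a 90 degree rotation matrix, width and height are swapped.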


def _apply_grid_transform(
    img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints._FillTypeJIT
) -> torch.Tensor:

    # We are using context knowledge that grid should have float dtype
    fp = img.dtype == grid.dtype
    float_img = img if fp else img.to(grid.dtype)

    shape = float_img.shape
    if shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(shape[0], -1, -1, -1)

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
        float_img = torch.cat((float_img, mask), dim=1)

    float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
        mask = mask.expand_as(float_img)
        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
        fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
        if mode == "nearest":
            bool_mask = mask < 0.5
            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
        else:  # 'bilinear'
            # The following is mathematically equivalent to:
            # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)

    img = float_img.round_().to(img.dtype) if not fp else float_img

    return img


def _assert_grid_transform_inputs(
    image: torch.Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: datapoints._FillTypeJIT,
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        elif len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None:
        if isinstance(fill, (tuple, list)):
            length = len(fill)
            num_channels = image.shape[-3]
            if length > 1 and length != num_channels:
                raise ValueError(
                    "The number of elements in 'fill' cannot broadcast to match the number of "
                    f"channels of the image ({length} != {num_channels})"
                )
        elif not isinstance(fill, (int, float)):
            raise ValueError("Argument fill should be either int, float, tuple or list")

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _affine_grid(
    theta: torch.Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> torch.Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
    dtype = theta.dtype
    device = theta.device

    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


def affine_image_tensor(
    image: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim

    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
        needs_unsquash = True
    elif ndim == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True
    else:
        needs_unsquash = False

    height, width = shape[-2:]

    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [float(t) for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    return output


@torch.jit.unused
def affine_image_pil(
    image: PIL.Image.Image,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    interpolation = _check_interpolation(interpolation)
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
    # it is visually better to estimate the center without 0.5 offset
    # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
    if center is None:
        height, width = get_spatial_size_image_pil(image)
        center = [width * 0.5, height * 0.5]
    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

    return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)


def _affine_bounding_box_with_expand(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    expand: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    if bounding_box.numel() == 0:
        return bounding_box, spatial_size

    original_shape = bounding_box.shape
    original_dtype = bounding_box.dtype
    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
    dtype = bounding_box.dtype
    device = bounding_box.device
    bounding_box = (
        convert_format_bounding_box(
            bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
        )
    ).reshape(-1, 4)

    angle, translate, shear, center = _affine_parse_args(
        angle, translate, scale, shear, InterpolationMode.NEAREST, center
    )

    if center is None:
        height, width = spatial_size
        center = [width * 0.5, height * 0.5]

    affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
    transposed_affine_matrix = (
        torch.tensor(
            affine_vector,
            dtype=dtype,
            device=device,
        )
        .reshape(2, 3)
        .T
    )

    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
    # 2) Now let's transform the points using affine matrix
    transformed_points = torch.matmul(points, transposed_affine_matrix)
    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

    if expand:
        # Compute minimum point for transformed image frame:
        # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
        height, width = spatial_size
        points = torch.tensor(
            [
                [0.0, 0.0, 1.0],
                [0.0, float(height), 1.0],
                [float(width), float(height), 1.0],
                [float(width), 0.0, 1.0],
            ],
            dtype=dtype,
            device=device,
        )
        new_points = torch.matmul(points, transposed_affine_matrix)
        tr = torch.amin(new_points, dim=0, keepdim=True)
        # Translate bounding boxes
        out_bboxes.sub_(tr.repeat((1, 2)))
        # Estimate meta-data for image with inverted=True and with center=[0,0]
        affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
        new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
        spatial_size = (new_height, new_width)

    out_bboxes = clamp_bounding_box(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
    out_bboxes = convert_format_bounding_box(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)

    out_bboxes = out_bboxes.to(original_dtype)
    return out_bboxes, spatial_size


def affine_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    out_box, _ = _affine_bounding_box_with_expand(
        bounding_box,
        format=format,
        spatial_size=spatial_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
        expand=False,
    )
    return out_box


def affine_mask(
    mask: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = affine_image_tensor(
        mask,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output


def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    return affine_image_tensor(
        video,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )


def affine(
    inpt: datapoints._InputTypeJIT,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: datapoints._FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(affine)

    # TODO: consider deprecating integers from angle and shear in the future
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return affine_image_tensor(
            inpt,
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.affine(
            angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
        )
    elif isinstance(inpt, PIL.Image.Image):
        return affine_image_pil(
            inpt,
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
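
# Illustrative note (a sketch, not library documentation): angle=0, translate=[0, 0],
# scale=1.0 and shear=[0, 0] describe the identity warp, so the call below returns the
# input unchanged up to interpolation/rounding effects:
#
#     out = affine(torch.rand(3, 32, 32), angle=0.0, translate=[0, 0], scale=1.0, shear=[0, 0])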


def rotate_image_tensor(
    image: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)

    shape = image.shape
    num_channels, height, width = shape[-3:]

    center_f = [0.0, 0.0]
    if center is not None:
        if expand:
            # TODO: Do we actually want to warn, or just document this?
            warnings.warn("The provided center argument has no effect on the result if expand is True")
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    # due to current incoherence of rotation angle direction between affine and rotate implementations
    # we need to set -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

    if image.numel() > 0:
        image = image.reshape(-1, num_channels, height, width)

        _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

        ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
        dtype = image.dtype if torch.is_floating_point(image) else torch.float32
        theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
        grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
        output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

        new_height, new_width = output.shape[-2:]
    else:
        output = image
        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)

    return output.reshape(shape[:-3] + (num_channels, new_height, new_width))


@torch.jit.unused
def rotate_image_pil(
    image: PIL.Image.Image,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> PIL.Image.Image:
    interpolation = _check_interpolation(interpolation)

    if center is not None and expand:
        warnings.warn("The provided center argument has no effect on the result if expand is True")
        center = None

    return _FP.rotate(
        image, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
    )


def rotate_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    if center is not None and expand:
        warnings.warn("The provided center argument has no effect on the result if expand is True")
        center = None

    return _affine_bounding_box_with_expand(
        bounding_box,
        format=format,
        spatial_size=spatial_size,
        angle=-angle,
        translate=[0.0, 0.0],
        scale=1.0,
        shear=[0.0, 0.0],
        center=center,
        expand=expand,
    )


def rotate_mask(
    mask: torch.Tensor,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = rotate_image_tensor(
        mask,
        angle=angle,
        expand=expand,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output


def rotate_video(
    video: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


def rotate(
    inpt: datapoints._InputTypeJIT,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(rotate)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
    elif isinstance(inpt, PIL.Image.Image):
        return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
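
# Illustrative usage (a sketch, not library documentation): with expand=True the output
# canvas grows to hold the whole rotated image, so a 90 degree rotation swaps height and
# width:
#
#     out = rotate(torch.rand(3, 24, 32), angle=90.0, expand=True)
#     # out.shape -> torch.Size([3, 32, 24])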


def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, (tuple, list)):
        if len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        elif len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        else:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )
    else:
        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")

    return [pad_left, pad_right, pad_top, pad_bottom]
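
# Illustrative mapping (not executed at import): `padding` follows the PIL argument order
# while the returned list follows `torch.nn.functional.pad`'s order:
#
#     _parse_pad_padding(2)             # -> [2, 2, 2, 2]
#     _parse_pad_padding([1, 2])        # -> [1, 1, 2, 2]   (left/right=1, top/bottom=2)
#     _parse_pad_padding([1, 2, 3, 4])  # -> [1, 3, 2, 4]   (left, right, top, bottom)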


def pad_image_tensor(
    image: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    # Be aware that while `padding` has the order `[left, top, right, bottom]`, `torch_padding` uses
    # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use
    # `torch_pad` internally.
    torch_padding = _parse_pad_padding(padding)

    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
        raise ValueError(
            f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
            f"but got `'{padding_mode}'`."
        )

    if fill is None:
        fill = 0

    if isinstance(fill, (int, float)):
        return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
    elif len(fill) == 1:
        return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
    else:
        return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)


def _pad_with_scalar_fill(
    image: torch.Tensor,
    torch_padding: List[int],
    fill: Union[int, float],
    padding_mode: str,
) -> torch.Tensor:
    shape = image.shape
    num_channels, height, width = shape[-3:]

    batch_size = 1
    for s in shape[:-3]:
        batch_size *= s

    image = image.reshape(batch_size, num_channels, height, width)

    if padding_mode == "edge":
        # Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus, we map
        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
        # name.
        padding_mode = "replicate"

    if padding_mode == "constant":
        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
    elif padding_mode in ("reflect", "replicate"):
        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
        # TODO: See https://github.com/pytorch/pytorch/issues/40763
        dtype = image.dtype
        if not image.is_floating_point():
            needs_cast = True
            image = image.to(torch.float32)
        else:
            needs_cast = False

        image = torch_pad(image, torch_padding, mode=padding_mode)

        if needs_cast:
            image = image.to(dtype)
    else:  # padding_mode == "symmetric"
        image = _pad_symmetric(image, torch_padding)

    new_height, new_width = image.shape[-2:]

    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


# TODO: This should be removed once torch_pad supports non-scalar padding values
def _pad_with_vector_fill(
    image: torch.Tensor,
    torch_padding: List[int],
    fill: List[float],
    padding_mode: str,
) -> torch.Tensor:
    if padding_mode != "constant":
        raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

    output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    left, right, top, bottom = torch_padding
    fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

    if top > 0:
        output[..., :top, :] = fill
    if left > 0:
        output[..., :, :left] = fill
    if bottom > 0:
        output[..., -bottom:, :] = fill
    if right > 0:
        output[..., :, -right:] = fill
    return output


pad_image_pil = _FP.pad


def pad_mask(
    mask: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    if fill is None:
        fill = 0

    if isinstance(fill, (tuple, list)):
        raise ValueError("Non-scalar fill value is not supported")

    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = pad_image_tensor(mask, padding=padding, fill=fill, padding_mode=padding_mode)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


def pad_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    padding: List[int],
    padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    if padding_mode not in ["constant"]:
        # TODO: add support of other padding modes
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

    left, right, top, bottom = _parse_pad_padding(padding)

    if format == datapoints.BoundingBoxFormat.XYXY:
        pad = [left, top, left, top]
    else:
        pad = [left, top, 0, 0]
    bounding_box = bounding_box + torch.tensor(pad, dtype=bounding_box.dtype, device=bounding_box.device)

    height, width = spatial_size
    height += top + bottom
    width += left + right
    spatial_size = (height, width)

    return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size
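
# Illustrative example (a sketch, not library documentation): padding shifts the boxes by
# the left/top padding and grows the canvas. Assuming an XYXY box in a 50x50 image:
#
#     boxes = torch.tensor([[10.0, 10.0, 20.0, 20.0]])
#     out, new_size = pad_bounding_box(
#         boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(50, 50), padding=[5]
#     )
#     # out -> tensor([[15., 15., 25., 25.]]), new_size -> (60, 60)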


def pad_video(
    video: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)


def pad(
    inpt: datapoints._InputTypeJIT,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(pad)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
    elif isinstance(inpt, PIL.Image.Image):
        return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
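
# Illustrative usage (a sketch, not library documentation): padding=[1, 2] pads 1 pixel on
# the left/right and 2 on the top/bottom, so a (3, 4, 4) input becomes (3, 8, 6):
#
#     out = pad(torch.zeros(3, 4, 4), padding=[1, 2], fill=255)
#     # out.shape -> torch.Size([3, 8, 6])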


def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    h, w = image.shape[-2:]

    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        image = image[..., max(top, 0) : bottom, max(left, 0) : right]
        torch_padding = [
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    return image[..., top:bottom, left:right]
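
# Illustrative note (a sketch, not library documentation): out-of-bounds crops are allowed
# and the missing area is implicitly zero-padded. For a (3, 4, 4) input,
# crop_image_tensor(img, top=-2, left=-2, height=8, width=8) returns a (3, 8, 8) tensor
# with the original content centered and a 2-pixel zero border on every side.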


crop_image_pil = _FP.crop


def crop_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    # Crop or implicit pad if left and/or top have negative values:
    if format == datapoints.BoundingBoxFormat.XYXY:
        sub = [left, top, left, top]
    else:
        sub = [left, top, 0, 0]

    bounding_box = bounding_box - torch.tensor(sub, dtype=bounding_box.dtype, device=bounding_box.device)
    spatial_size = (height, width)

    return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size


def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = crop_image_tensor(mask, top, left, height, width)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    return crop_image_tensor(video, top, left, height, width)


def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return crop_image_tensor(inpt, top, left, height, width)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.crop(top, left, height, width)
    elif isinstance(inpt, PIL.Image.Image):
        return crop_image_pil(inpt, top, left, height, width)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394
    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
    shape = (1, oh * ow, 3)
    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))

    output_grid = output_grid1.div_(output_grid2).sub_(1.0)
    return output_grid.view(1, oh, ow, 2)


def _perspective_coefficients(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]],
) -> List[float]:
    if coefficients is not None:
        if startpoints is not None and endpoints is not None:
            raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.")
        elif len(coefficients) != 8:
            raise ValueError("Argument coefficients should have 8 float values")
        return coefficients
    elif startpoints is not None and endpoints is not None:
        return _get_perspective_coeffs(startpoints, endpoints)
    else:
        raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")


def perspective_image_tensor(
    image: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim

    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
        needs_unsquash = True
    elif ndim == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True
    else:
        needs_unsquash = False

    _assert_grid_transform_inputs(
        image,
        matrix=None,
        interpolation=interpolation.value,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    oh, ow = shape[-2:]
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    return output


@torch.jit.unused
def perspective_image_pil(
    image: PIL.Image.Image,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)
    return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)


def perspective_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    if bounding_box.numel() == 0:
        return bounding_box

    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)

    original_shape = bounding_box.shape
    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
    bounding_box = (
        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
    device = bounding_box.device

    # perspective_coeffs are computed as endpoint -> start point
    # We have to invert perspective_coeffs for bboxes:
    # (x, y) - end point and (x_out, y_out) - start point
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # and we would like to get:
    # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # and compute inv_coeffs in terms of coeffs

    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
    if denom == 0:
        raise RuntimeError(
            f"Provided perspective_coeffs {perspective_coeffs} cannot be inverted to transform bounding boxes. "
            f"Denominator is zero, denom={denom}"
        )

    inv_coeffs = [
        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
        (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
        (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
        (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
    ]

    theta1 = torch.tensor(
        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
        dtype=dtype,
        device=device,
    )

    theta2 = torch.tensor(
        [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
    )

    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-right, bottom-left corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
    # 2) Now let's transform the points using perspective matrices
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)

    numer_points = torch.matmul(points, theta1.T)
    denom_points = torch.matmul(points, theta2.T)
    transformed_points = numer_points.div_(denom_points)

    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)

    out_bboxes = clamp_bounding_box(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=spatial_size,
    )

    # out_bboxes should be of shape [N boxes, 4]

    return convert_format_bounding_box(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)

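# Numerical sanity-check sketch (illustrative only; the helper name, the corner points, and
# the probe point are made-up example values): composing the forward perspective mapping
# with coefficients inverted as in perspective_bounding_box should return the original
# point, up to floating point error.
def _example_check_inverse_perspective_coeffs() -> Tuple[float, float]:
    # arguments are (startpoints, endpoints)
    c = _get_perspective_coeffs(
        [[0, 0], [33, 0], [33, 25], [0, 25]],
        [[3, 2], [32, 3], [30, 24], [2, 25]],
    )
    denom = c[0] * c[4] - c[1] * c[3]
    inv = [
        (c[4] - c[5] * c[7]) / denom,
        (-c[1] + c[2] * c[7]) / denom,
        (c[1] * c[5] - c[2] * c[4]) / denom,
        (-c[3] + c[5] * c[6]) / denom,
        (c[0] - c[2] * c[6]) / denom,
        (-c[0] * c[5] + c[2] * c[3]) / denom,
        (-c[4] * c[6] + c[3] * c[7]) / denom,
        (-c[0] * c[7] + c[1] * c[6]) / denom,
    ]
    x, y = 10.0, 7.0
    # forward mapping: (x, y) -> (u, v)
    u = (c[0] * x + c[1] * y + c[2]) / (c[6] * x + c[7] * y + 1)
    v = (c[3] * x + c[4] * y + c[5]) / (c[6] * x + c[7] * y + 1)
    # inverse mapping: (u, v) -> approximately (10.0, 7.0) again
    x_back = (inv[0] * u + inv[1] * v + inv[2]) / (inv[6] * u + inv[7] * v + 1)
    y_back = (inv[3] * u + inv[4] * v + inv[5]) / (inv[6] * u + inv[7] * v + 1)
    return x_back, y_back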

def perspective_mask(
    mask: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = perspective_image_tensor(
        mask, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output


def perspective_video(
    video: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    return perspective_image_tensor(
        video, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
    )


def perspective(
    inpt: datapoints._InputTypeJIT,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(perspective)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return perspective_image_tensor(
            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.perspective(
            startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
        )
    elif isinstance(inpt, PIL.Image.Image):
        return perspective_image_pil(
            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )

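# Usage sketch (illustrative; the helper name, image, and corner points are made-up example
# inputs): the dispatcher above routes plain tensors, datapoints, and PIL images to the
# matching kernels.
def _example_perspective_usage() -> torch.Tensor:
    img = torch.rand(3, 32, 32)
    startpoints = [[0, 0], [31, 0], [31, 31], [0, 31]]
    endpoints = [[2, 1], [29, 2], [31, 31], [0, 30]]
    # A plain tensor input ends up in perspective_image_tensor.
    return perspective(img, startpoints, endpoints)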

def elastic_image_tensor(
    image: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim

    device = image.device
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32

    # Patch: elastic transform should support (cpu,f16) input
    is_cpu_half = device.type == "cpu" and dtype == torch.float16
    if is_cpu_half:
        image = image.to(torch.float32)
        dtype = torch.float32

    # We are aware that if the input image dtype is uint8 and displacement is float64 then
    # displacement will be cast to float32 and all computations will be done with float32
    # We can fix this later if needed

    expected_shape = (1,) + shape[-2:] + (2,)
    if expected_shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
        needs_unsquash = True
    elif ndim == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True
    else:
        needs_unsquash = False

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    image_height, image_width = shape[-2:]
    grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    if is_cpu_half:
        output = output.to(torch.float16)

    return output


@torch.jit.unused
def elastic_image_pil(
    image: PIL.Image.Image,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> PIL.Image.Image:
    t_img = pil_to_tensor(image)
    output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
    return to_pil_image(output, mode=image.mode)


def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    sy, sx = size
    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
    base_grid[..., 0].copy_(x_grid)

    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)

    return base_grid

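# Usage sketch (illustrative; the helper name, shapes, and values are made-up example
# inputs): the identity grid holds normalized (x, y) sampling locations in (-1, 1), so a
# zero displacement leaves the image unchanged while small offsets jitter the sampling
# locations locally.
def _example_elastic_usage() -> torch.Tensor:
    img = torch.rand(3, 16, 16)
    # displacement must have shape (1, H, W, 2), expressed in the same normalized units
    displacement = torch.zeros(1, 16, 16, 2)
    displacement[..., 0] += 0.05  # nudge the x sampling coordinate of every output pixel
    return elastic_image_tensor(img, displacement)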

def elastic_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    displacement: torch.Tensor,
) -> torch.Tensor:
    if bounding_box.numel() == 0:
        return bounding_box

    # TODO: add in docstring about approximation we are doing for grid inversion
    device = bounding_box.device
    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    original_shape = bounding_box.shape
    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
    bounding_box = (
        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
    # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
    # This is not an exact inverse of the grid
    inv_grid = id_grid.sub_(displacement)

    # Get points from bboxes
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    if points.is_floating_point():
        points = points.ceil_()
    index_xy = points.to(dtype=torch.long)
    index_x, index_y = index_xy[:, 0], index_xy[:, 1]

    # Transform points:
    t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = clamp_bounding_box(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=spatial_size,
    )

    return convert_format_bounding_box(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)


def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> torch.Tensor:
    return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)


def elastic(
    inpt: datapoints._InputTypeJIT,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(elastic)

    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
    elif isinstance(inpt, PIL.Image.Image):
        return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


elastic_transform = elastic


def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    if isinstance(output_size, numbers.Number):
        s = int(output_size)
        return [s, s]
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        return [output_size[0], output_size[0]]
    else:
        return list(output_size)


def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
    return [
        (crop_width - image_width) // 2 if crop_width > image_width else 0,
        (crop_height - image_height) // 2 if crop_height > image_height else 0,
        (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
        (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
    ]


def _center_crop_compute_crop_anchor(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop_top, crop_left

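# Worked example (illustrative; the helper name and sizes are made up): for a 10x10 image,
# a 6x4 (height x width) crop needs no padding and is anchored at row 2, column 3, while a
# 12x12 crop would first pad the image by one pixel on every side.
def _example_center_crop_geometry() -> Tuple[int, int]:
    assert _center_crop_compute_padding(12, 12, 10, 10) == [1, 1, 1, 1]
    return _center_crop_compute_crop_anchor(6, 4, 10, 10)  # (2, 3)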

def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    shape = image.shape
    if image.numel() == 0:
        return image.reshape(shape[:-2] + (crop_height, crop_width))
    image_height, image_width = shape[-2:]

    if crop_height > image_height or crop_width > image_width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)

        image_height, image_width = image.shape[-2:]
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]


@torch.jit.unused
def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    image_height, image_width = get_spatial_size_image_pil(image)

    if crop_height > image_height or crop_width > image_width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = pad_image_pil(image, padding_ltrb, fill=0)

        image_height, image_width = get_spatial_size_image_pil(image)
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)


def center_crop_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    spatial_size: Tuple[int, int],
    output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size)
    return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width)

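# Usage sketch (illustrative; the helper name, box coordinates, and sizes are made-up
# example inputs): center-cropping a bounding box shifts its coordinates by the crop
# anchor and returns the new spatial size alongside.
def _example_center_crop_bounding_box_usage() -> Tuple[torch.Tensor, Tuple[int, int]]:
    boxes = torch.tensor([[10.0, 10.0, 20.0, 20.0]])
    return center_crop_bounding_box(
        boxes, datapoints.BoundingBoxFormat.XYXY, spatial_size=(64, 64), output_size=[32, 32]
    )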

def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = center_crop_image_tensor(image=mask, output_size=output_size)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    return center_crop_image_tensor(video, output_size)


def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(center_crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return center_crop_image_tensor(inpt, output_size)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.center_crop(output_size)
    elif isinstance(inpt, PIL.Image.Image):
        return center_crop_image_pil(inpt, output_size)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )

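# Usage sketch (illustrative; the helper name and tensor are made-up example inputs): a
# plain tensor goes through center_crop_image_tensor, while datapoints and PIL images are
# dispatched to their own kernels.
def _example_center_crop_usage() -> torch.Tensor:
    img = torch.rand(3, 64, 48)
    return center_crop(img, [32, 32])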

def resized_crop_image_tensor(
    image: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    image = crop_image_tensor(image, top, left, height, width)
    return resize_image_tensor(image, size, interpolation=interpolation, antialias=antialias)

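# Usage sketch (illustrative; the helper name, crop window, and sizes are made-up example
# inputs): a resized crop is simply crop followed by resize, the building block used by
# transforms such as RandomResizedCrop.
def _example_resized_crop_usage() -> torch.Tensor:
    img = torch.rand(3, 100, 100)
    return resized_crop_image_tensor(
        img, top=10, left=20, height=50, width=60, size=[224, 224], antialias=True
    )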

@torch.jit.unused
def resized_crop_image_pil(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
    image = crop_image_pil(image, top, left, height, width)
    return resize_image_pil(image, size, interpolation=interpolation)


def resized_crop_bounding_box(
    bounding_box: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    bounding_box, _ = crop_bounding_box(bounding_box, format, top, left, height, width)
    return resize_bounding_box(bounding_box, spatial_size=(height, width), size=size)


def resized_crop_mask(
    mask: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> torch.Tensor:
    mask = crop_mask(mask, top, left, height, width)
    return resize_mask(mask, size)


def resized_crop_video(
    video: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    return resized_crop_image_tensor(
        video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
    )


def resized_crop(
    inpt: datapoints._InputTypeJIT,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(resized_crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return resized_crop_image_tensor(
            inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
        )
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
    elif isinstance(inpt, PIL.Image.Image):
        return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
        s = int(size)
        size = [s, s]
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        s = size[0]
        size = [s, s]

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size


def five_crop_image_tensor(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = image.shape[-2:]

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    tl = crop_image_tensor(image, 0, 0, crop_height, crop_width)
    tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop_image_tensor(image, image_height - crop_height, 0, crop_height, crop_width)
    br = crop_image_tensor(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop_image_tensor(image, [crop_height, crop_width])

    return tl, tr, bl, br, center

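# Usage sketch (illustrative; the helper name and shapes are made-up example inputs):
# five_crop returns the four corner crops plus the center crop, each of the requested size.
def _example_five_crop_usage() -> Tuple[int, int]:
    img = torch.rand(3, 64, 64)
    tl, tr, bl, br, center = five_crop_image_tensor(img, [32, 32])
    return tuple(center.shape[-2:])  # (32, 32)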

@torch.jit.unused
def five_crop_image_pil(
    image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = get_spatial_size_image_pil(image)

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    tl = crop_image_pil(image, 0, 0, crop_height, crop_width)
    tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width)
    br = crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop_image_pil(image, [crop_height, crop_width])

    return tl, tr, bl, br, center


def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    return five_crop_image_tensor(video, size)


ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]


def five_crop(
    inpt: ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
    if not torch.jit.is_scripting():
        _log_api_usage_once(five_crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return five_crop_image_tensor(inpt, size)
    elif isinstance(inpt, datapoints.Image):
        output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
        return tuple(datapoints.Image.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
    elif isinstance(inpt, datapoints.Video):
        output = five_crop_video(inpt.as_subclass(torch.Tensor), size)
        return tuple(datapoints.Video.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
    elif isinstance(inpt, PIL.Image.Image):
        return five_crop_image_pil(inpt, size)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


def ten_crop_image_tensor(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    non_flipped = five_crop_image_tensor(image, size)

    if vertical_flip:
        image = vertical_flip_image_tensor(image)
    else:
        image = horizontal_flip_image_tensor(image)

    flipped = five_crop_image_tensor(image, size)

    return non_flipped + flipped

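# Usage sketch (illustrative; the helper name and shapes are made-up example inputs):
# ten_crop is five_crop of the image followed by five_crop of its flipped version
# (horizontal flip by default, vertical if requested), concatenated in that order.
def _example_ten_crop_usage() -> int:
    img = torch.rand(3, 64, 64)
    crops = ten_crop_image_tensor(img, [32, 32])
    return len(crops)  # 10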

@torch.jit.unused
def ten_crop_image_pil(
    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
]:
    non_flipped = five_crop_image_pil(image, size)

    if vertical_flip:
        image = vertical_flip_image_pil(image)
    else:
        image = horizontal_flip_image_pil(image)

    flipped = five_crop_image_pil(image, size)

    return non_flipped + flipped


def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)


def ten_crop(
    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False
) -> Tuple[
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
    ImageOrVideoTypeJIT,
]:
    if not torch.jit.is_scripting():
        _log_api_usage_once(ten_crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
    elif isinstance(inpt, datapoints.Image):
        output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
        return tuple(datapoints.Image.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
    elif isinstance(inpt, datapoints.Video):
        output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
        return tuple(datapoints.Video.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
    elif isinstance(inpt, PIL.Image.Image):
        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )