# _geometry.py
import math
import numbers
import warnings
from typing import Any, List, Optional, Sequence, Tuple, Union

import PIL.Image
import torch
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad

from torchvision import datapoints
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric
from torchvision.transforms.functional import (
    _check_antialias,
    _compute_resized_output_size as __compute_resized_output_size,
    _get_perspective_coeffs,
    _interpolation_modes_from_int,
    InterpolationMode,
    pil_modes_mapping,
    pil_to_tensor,
    to_pil_image,
)
from torchvision.utils import _log_api_usage_once

from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
from ._utils import (
    _FillTypeJIT,
    _get_kernel,
    _register_explicit_noop,
    _register_five_ten_crop_kernel,
    _register_kernel_internal,
)


def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    """Normalize ``interpolation`` into an ``InterpolationMode``.

    Integers are interpreted as Pillow resampling constants and converted with
    ``_interpolation_modes_from_int``; ``InterpolationMode`` members are returned as-is.

    Raises:
        ValueError: if ``interpolation`` is neither an int nor an ``InterpolationMode``.
    """
    if isinstance(interpolation, int):
        return _interpolation_modes_from_int(interpolation)
    if not isinstance(interpolation, InterpolationMode):
        raise ValueError(
            f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
            f"but got {interpolation}."
        )
    return interpolation


48
def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
    """Flip ``inpt`` horizontally.

    Under TorchScript the input is always handled as a plain image tensor. Otherwise the
    kernel registered for ``type(inpt)`` (image tensor, PIL image, mask, boxes, video) is
    looked up with ``_get_kernel`` and applied.
    """
    if torch.jit.is_scripting():
        return horizontal_flip_image_tensor(inpt)

    _log_api_usage_once(horizontal_flip)

    kernel = _get_kernel(horizontal_flip, type(inpt))
    return kernel(inpt)
56
57


58
@_register_kernel_internal(horizontal_flip, torch.Tensor)
@_register_kernel_internal(horizontal_flip, datapoints.Image)
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Horizontally flip an image tensor by reversing its last (width) dimension."""
    return image.flip(-1)


64
@_register_kernel_internal(horizontal_flip, PIL.Image.Image)
def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """Horizontally flip a PIL image using the ``_FP.hflip`` helper."""
    return _FP.hflip(image)
67
68


69
@_register_kernel_internal(horizontal_flip, datapoints.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    """Horizontally flip a mask with the image-tensor kernel (reverses the last dimension)."""
    return horizontal_flip_image_tensor(mask)
72
73


74
def horizontal_flip_bounding_boxes(
    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    """Mirror bounding boxes horizontally inside a canvas.

    Args:
        bounding_boxes: tensor holding 4 coordinates per box in its trailing dimension; leading
            dimensions are preserved. The input is not modified, because a clone is edited.
        format: coordinate format of ``bounding_boxes`` (XYXY, XYWH or CXCYWH).
        canvas_size: ``(height, width)``. Only the width (``canvas_size[1]``) is used.

    Returns:
        The flipped boxes, in the same format and shape as the input.
    """
    shape = bounding_boxes.shape

    bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

    if format == datapoints.BoundingBoxFormat.XYXY:
        # new x1 = W - old x2 and new x2 = W - old x1, hence the swapped column order.
        bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_()
    elif format == datapoints.BoundingBoxFormat.XYWH:
        # new x = W - (x + w); the width is unchanged. Column 0 is edited in place through a view.
        bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_()
    else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
        # new cx = W - cx.
        bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()

    return bounding_boxes.reshape(shape)
89
90


91
92
93
94
95
96
97
98
99
@_register_kernel_internal(horizontal_flip, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes:
    """Unwrap ``inpt``, flip the raw boxes using its own format and canvas size, then re-wrap them."""
    raw_boxes = inpt.as_subclass(torch.Tensor)
    flipped_boxes = horizontal_flip_bounding_boxes(raw_boxes, format=inpt.format, canvas_size=inpt.canvas_size)
    return datapoints.BoundingBoxes.wrap_like(inpt, flipped_boxes)


@_register_kernel_internal(horizontal_flip, datapoints.Video)
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    """Horizontally flip a video with the image-tensor kernel (reverses the last dimension)."""
    return horizontal_flip_image_tensor(video)


104
def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
    """Flip ``inpt`` vertically.

    Under TorchScript the input is always handled as a plain image tensor. Otherwise the
    kernel registered for ``type(inpt)`` is looked up with ``_get_kernel`` and applied.
    """
    if torch.jit.is_scripting():
        return vertical_flip_image_tensor(inpt)

    _log_api_usage_once(vertical_flip)

    kernel = _get_kernel(vertical_flip, type(inpt))
    return kernel(inpt)
112
113


114
@_register_kernel_internal(vertical_flip, torch.Tensor)
@_register_kernel_internal(vertical_flip, datapoints.Image)
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Vertically flip an image tensor by reversing its second-to-last (height) dimension."""
    return image.flip(-2)


120
@_register_kernel_internal(vertical_flip, PIL.Image.Image)
def vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """Vertically flip a PIL image using the ``_FP.vflip`` helper.

    The annotations name the ``PIL.Image.Image`` class, as in the other PIL kernels;
    ``PIL.Image`` is the module, not the image type.
    """
    return _FP.vflip(image)
123
124


125
@_register_kernel_internal(vertical_flip, datapoints.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    """Vertically flip a mask with the image-tensor kernel (reverses dimension -2)."""
    return vertical_flip_image_tensor(mask)
128
129


130
def vertical_flip_bounding_boxes(
    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    """Mirror bounding boxes vertically inside a canvas.

    Args:
        bounding_boxes: tensor holding 4 coordinates per box in its trailing dimension; leading
            dimensions are preserved. The input is not modified, because a clone is edited.
        format: coordinate format of ``bounding_boxes`` (XYXY, XYWH or CXCYWH).
        canvas_size: ``(height, width)``. Only the height (``canvas_size[0]``) is used.

    Returns:
        The flipped boxes, in the same format and shape as the input.
    """
    shape = bounding_boxes.shape

    bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

    if format == datapoints.BoundingBoxFormat.XYXY:
        # new y1 = H - old y2 and new y2 = H - old y1, hence the swapped column order.
        bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_()
    elif format == datapoints.BoundingBoxFormat.XYWH:
        # new y = H - (y + h); the height is unchanged. Column 1 is edited in place through a view.
        bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_()
    else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
        # new cy = H - cy.
        bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()

    return bounding_boxes.reshape(shape)
145
146


147
148
149
150
151
152
@_register_kernel_internal(vertical_flip, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes:
    """Unwrap ``inpt``, flip the raw boxes using its own format and canvas size, then re-wrap them."""
    raw_boxes = inpt.as_subclass(torch.Tensor)
    flipped_boxes = vertical_flip_bounding_boxes(raw_boxes, format=inpt.format, canvas_size=inpt.canvas_size)
    return datapoints.BoundingBoxes.wrap_like(inpt, flipped_boxes)
153

154

155
156
157
@_register_kernel_internal(vertical_flip, datapoints.Video)
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
    """Vertically flip a video by delegating to the image-tensor kernel, which flips dimension -2."""
    flipped = vertical_flip_image_tensor(video)
    return flipped
158
159


160
161
162
163
164
165
# The long names mirror the transform classes (e.g. `RandomHorizontalFlip`), but the short names `hflip` and
# `vflip` are widely used and well understood, so they stay available as plain aliases (not deprecated).
hflip, vflip = horizontal_flip, vertical_flip


166
def _compute_resized_output_size(
    canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
    """Compute the ``(height, width)`` a resize to ``size`` would produce for ``canvas_size``.

    A bare int ``size`` is wrapped into a one-element list. The actual computation is
    delegated to torchvision's ``_compute_resized_output_size``.

    Raises:
        ValueError: if ``max_size`` is given while ``size`` is a sequence whose length is not 1.
    """
    if isinstance(size, int):
        size = [size]
    elif max_size is not None and len(size) != 1:
        raise ValueError(
            "max_size should only be passed if size specifies the length of the smaller edge, "
            "i.e. size should be an int or a sequence of length 1 in torchscript mode."
        )
    return __compute_resized_output_size(canvas_size, size=size, max_size=max_size)
177
178


179
def resize(
    inpt: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Resize ``inpt`` to ``size``.

    Under TorchScript the input is resized as a plain image tensor. Otherwise the kernel
    registered for ``type(inpt)`` is looked up with ``_get_kernel`` and receives all options.
    Kernels for types without interpolation (masks, boxes) accept and ignore the extra ones.
    """
    if torch.jit.is_scripting():
        return resize_image_tensor(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)

    _log_api_usage_once(resize)

    kernel = _get_kernel(resize, type(inpt))
    return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
193
194


195
@_register_kernel_internal(resize, torch.Tensor)
@_register_kernel_internal(resize, datapoints.Image)
def resize_image_tensor(
    image: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Resize the last two dimensions of an image tensor laid out as ``(..., C, H, W)``.

    Args:
        image: tensor whose last three dims are read as channels, height, width. Leading dims
            are flattened into a batch for ``interpolate`` and restored afterwards.
        size: target size, resolved together with ``max_size`` by ``_compute_resized_output_size``.
        interpolation: ``InterpolationMode`` or Pillow integer constant.
        max_size: optional bound forwarded to ``_compute_resized_output_size``.
        antialias: validated by ``_check_antialias``. ``None`` is treated as False, and the value
            is forced to False for modes other than bilinear/bicubic.

    Returns:
        A tensor of shape ``(..., C, new_H, new_W)`` with the input dtype. When the size does not
        change, the input object itself is returned.
    """
    interpolation = _check_interpolation(interpolation)
    antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation)
    assert not isinstance(antialias, str)
    antialias = False if antialias is None else antialias
    align_corners: Optional[bool] = None
    if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
        align_corners = False
    else:
        # The default of antialias should be True from 0.17, so we don't warn or
        # error if other interpolation modes are used. This is documented.
        antialias = False

    shape = image.shape
    numel = image.numel()
    num_channels, old_height, old_width = shape[-3:]
    new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return image
    elif numel > 0:
        image = image.reshape(-1, num_channels, old_height, old_width)

        dtype = image.dtype
        acceptable_dtypes = [torch.float32, torch.float64]
        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
            # uint8 dtype can be included for cpu and cuda input if nearest mode
            acceptable_dtypes.append(torch.uint8)
        elif image.device.type == "cpu":
            # uint8 dtype support for bilinear and bicubic is limited to cpu and
            # according to our benchmarks, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
            if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (
                interpolation == InterpolationMode.BICUBIC
            ):
                acceptable_dtypes.append(torch.uint8)

        strides = image.stride()
        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
            # contiguous even though the input is un-ambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
            # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
            # we should be able to remove this hack.
            new_strides = list(strides)
            new_strides[0] = numel
            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)

        need_cast = dtype not in acceptable_dtypes
        if need_cast:
            image = image.to(dtype=torch.float32)

        image = interpolate(
            image,
            size=[new_height, new_width],
            mode=interpolation.value,
            align_corners=align_corners,
            antialias=antialias,
        )

        if need_cast:
            if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
                # Only reached for non-CPU devices: on CPU, uint8 bicubic is accepted natively above.
                # Bicubic can overshoot the input range, so clamp before casting back to uint8.
                image = image.clamp_(min=0, max=255)
            if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
                image = image.round_()
            image = image.to(dtype=dtype)

    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
274
275
276


def resize_image_pil(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
) -> PIL.Image.Image:
    """Resize a PIL image with ``Image.resize``.

    The output size comes from ``_compute_resized_output_size``. When the size would not change,
    the input image object is returned unchanged. ``interpolation`` is mapped to a Pillow
    resampling filter through ``pil_modes_mapping``.
    """
    old_height, old_width = image.height, image.width
    new_height, new_width = _compute_resized_output_size(
        (old_height, old_width),
        size=size,  # type: ignore[arg-type]
        max_size=max_size,
    )

    interpolation = _check_interpolation(interpolation)

    if (new_height, new_width) == (old_height, old_width):
        return image

    # Pillow expects (width, height).
    return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation])
295
296


297
298
299
300
301
302
303
304
305
306
307
308
309
@_register_kernel_internal(resize, PIL.Image.Image)
def _resize_image_pil_dispatch(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> PIL.Image.Image:
    """Resize kernel for PIL images.

    ``antialias`` is not forwarded. A warning is emitted only when the caller explicitly passes
    ``False``, since that request cannot be honoured for PIL input.
    """
    antialias_disabled = antialias is False
    if antialias_disabled:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)


310
311
312
def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
    """Resize a mask with nearest-neighbour interpolation, so no new label values are blended in.

    Masks with fewer than 3 dimensions get a temporary leading dimension. This lets the
    ``(..., C, H, W)`` image kernel process them, and the extra dimension is removed afterwards.
    """
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = resize_image_tensor(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)

    if needs_squeeze:
        output = output.squeeze(0)

    return output
323
324


325
326
327
328
329
330
331
332
@_register_kernel_internal(resize, datapoints.Mask, datapoint_wrapper=False)
def _resize_mask_dispatch(
    inpt: datapoints.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> datapoints.Mask:
    """Resize a ``Mask`` datapoint and re-wrap the result.

    Image-only options sent by ``resize`` (``interpolation``, ``antialias``) land in ``**kwargs``
    and are ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    resized_mask = resize_mask(plain_mask, size, max_size=max_size)
    return datapoints.Mask.wrap_like(inpt, resized_mask)


333
def resize_bounding_boxes(
    bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Scale bounding boxes to match an image resized from ``canvas_size`` to ``size``.

    The 4 trailing coordinates are multiplied by ``[w_ratio, h_ratio, w_ratio, h_ratio]``. This
    fits XYXY, XYWH and CXCYWH alike, because in each format the columns alternate x-like and
    y-like quantities. The result is cast back to the input dtype, so integer boxes lose their
    fractional part.

    Returns:
        ``(boxes, (new_height, new_width))``. When the size does not change, the inputs are
        returned unchanged.
    """
    old_height, old_width = canvas_size
    new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return bounding_boxes, canvas_size

    w_ratio = new_width / old_width
    h_ratio = new_height / old_height
    ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device)
    return (
        bounding_boxes.mul(ratios).to(bounding_boxes.dtype),
        (new_height, new_width),
    )
349
350


351
352
353
354
355
356
357
358
359
360
361
@_register_kernel_internal(resize, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _resize_bounding_boxes_dispatch(
    inpt: datapoints.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> datapoints.BoundingBoxes:
    """Resize a ``BoundingBoxes`` datapoint and re-wrap it with the updated canvas size.

    Image-only options sent by ``resize`` land in ``**kwargs`` and are ignored.
    """
    raw_boxes = inpt.as_subclass(torch.Tensor)
    resized_boxes, new_canvas_size = resize_bounding_boxes(raw_boxes, inpt.canvas_size, size, max_size=max_size)
    return datapoints.BoundingBoxes.wrap_like(inpt, resized_boxes, canvas_size=new_canvas_size)


@_register_kernel_internal(resize, datapoints.Video)
def resize_video(
    video: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Resize a video with the image-tensor kernel, which resizes the last two dims and keeps leading dims."""
    return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


372
def affine(
    inpt: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transform (rotation, translation, scale, shear) to ``inpt``.

    Under TorchScript the input is handled as a plain image tensor. Otherwise the kernel
    registered for ``type(inpt)`` is looked up with ``_get_kernel`` and receives every parameter.
    """
    if torch.jit.is_scripting():
        return affine_image_tensor(
            inpt,
            angle=angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )

    _log_api_usage_once(affine)

    kernel = _get_kernel(affine, type(inpt))
    return kernel(
        inpt,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )
407
408


409
def _affine_parse_args(
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
    """Validate and normalize the affine parameters shared by the affine kernels.

    ``scale`` and ``interpolation`` are only validated; they are not returned.

    Returns:
        ``(angle, translate, shear, center)``:
        ``angle`` is a float, and ``translate`` is a list.
        ``shear`` is a two-element list: a scalar ``s`` becomes ``[s, 0.0]`` and a one-element
        sequence ``[s]`` becomes ``[s, s]``.
        ``center`` is ``None`` or a list of floats.

    Raises:
        TypeError: for a non int/float angle, a non-sequence translate or center, a shear that is
            neither a number nor a list/tuple, or an interpolation that is not an ``InterpolationMode``.
        ValueError: for a translate not of length 2, a non-positive scale, or a shear that does not
            end up with exactly two values.
    """
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None:
        if not isinstance(center, (list, tuple)):
            raise TypeError("Argument center should be a sequence")
        else:
            center = [float(c) for c in center]

    return angle, translate, shear, center


462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
def _get_inverse_affine_matrix(
    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
    """Build the 2x3 affine matrix (row-major, 6 floats) for a rotate/scale/shear about ``center``.

    Args:
        center: ``[cx, cy]`` pivot about which rotation, scale and shear are applied.
        angle: rotation angle in degrees (converted with ``math.radians``).
        translate: ``[tx, ty]`` translation applied after the center-relative transform.
        scale: scale factor. With ``inverted=True`` the matrix entries are divided by it.
        shear: ``[sx, sy]`` shear angles in degrees.
        inverted: if True (default), return the inverse matrix M^-1. The comment below says this
            is what Pillow expects. If False, return the forward matrix M.

    Returns:
        ``[m0, m1, m2, m3, m4, m5]`` for the rows ``[m0, m1, m2]`` and ``[m3, m4, m5]``. The entries
        ``m2`` and ``m5`` carry the translation terms.
    """
    # Helper method to compute inverse matrix for affine transformation

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                      , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # Cached results
    cos_sy = math.cos(sy)
    tan_sx = math.tan(sx)
    rot_minus_sy = rot - sy
    cx_plus_tx = cx + tx
    cy_plus_ty = cy + ty

    # Rotate Scale Shear (RSS) without scaling
    a = math.cos(rot_minus_sy) / cos_sy
    b = -(a * tan_sx + math.sin(rot))
    c = math.sin(rot_minus_sy) / cos_sy
    d = math.cos(rot) - c * tan_sx

    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        # so the 2x2 inverse is just the adjugate [[d, -b], [-c, a]], divided by the scale.
        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
    else:
        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
        # Apply inverse of center translation: RSS * C^-1
        # and then apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy

    return matrix


def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired of PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
    # Points are shifted due to affine matrix torch convention about
    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
    half_w = 0.5 * w
    half_h = 0.5 * h
    pts = torch.tensor(
        [
            [-half_w, -half_h, 1.0],
            [-half_w, half_h, 1.0],
            [half_w, half_h, 1.0],
            [half_w, -half_h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, max_vals = new_pts.aminmax(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    halfs = torch.tensor((half_w, half_h))
    min_vals.add_(halfs)
    max_vals.add_(halfs)

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    inv_tol = 1.0 / tol
    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
    size = cmax.sub_(cmin)
    return int(size[0]), int(size[1])  # w, h


558
def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
    """Resample ``img`` at the locations in ``grid`` with ``grid_sample`` and fill out-of-bounds pixels.

    Args:
        img: 4D batch, indexed as ``(N, C, H, W)`` below.
        grid: sampling grid whose dtype sets the floating-point working dtype. When ``N > 1``,
            it is expanded over the batch.
        mode: ``grid_sample`` mode. ``"nearest"`` gets a hard fill; other modes blend the fill by
            the interpolated coverage mask.
        fill: value(s) for pixels sampled from outside the input. ``None`` keeps ``grid_sample``'s
            zero padding.

    Returns:
        The resampled tensor in ``img``'s dtype. If that dtype differs from ``grid``'s, values are
        rounded before casting back.
    """
    # We are using context knowledge that grid should have float dtype
    fp = img.dtype == grid.dtype
    float_img = img if fp else img.to(grid.dtype)

    shape = float_img.shape
    if shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(shape[0], -1, -1, -1)

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
        float_img = torch.cat((float_img, mask), dim=1)

    float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
        mask = mask.expand_as(float_img)
        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
        fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
        if mode == "nearest":
            bool_mask = mask < 0.5
            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
        else:  # 'bilinear'
            # The following is mathematically equivalent to:
            # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)

    img = float_img.round_().to(img.dtype) if not fp else float_img

    return img
593
594
595
596
597
598


def _assert_grid_transform_inputs(
    image: torch.Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: _FillTypeJIT,
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    """Validate the arguments of a grid-based geometric transform.

    Raises:
        TypeError: if ``matrix`` is given but is not a list.
        ValueError: if ``matrix`` does not have 6 values, ``coeffs`` does not have 8 values, ``fill``
            is a sequence that cannot broadcast to the channel count (``image.shape[-3]``), ``fill``
            has an unsupported type, or ``interpolation`` is not in ``supported_interpolation_modes``.
    """
    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        elif len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None:
        if isinstance(fill, (tuple, list)):
            length = len(fill)
            num_channels = image.shape[-3]
            # A single-element fill broadcasts to all channels; otherwise it must match exactly.
            if length > 1 and length != num_channels:
                raise ValueError(
                    "The number of elements in 'fill' cannot broadcast to match the number of "
                    f"channels of the image ({length} != {num_channels})"
                )
        elif not isinstance(fill, (int, float)):
            raise ValueError("Argument fill should be either int, float, tuple or list")

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _affine_grid(
    theta: torch.Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> torch.Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
    dtype = theta.dtype
    device = theta.device

    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


655
@_register_kernel_internal(affine, torch.Tensor)
@_register_kernel_internal(affine, datapoints.Image)
def affine_image_tensor(
    image: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transform to an image tensor laid out as ``(..., C, H, W)``.

    Tensors with more than 4 dims are flattened to a 4D batch, and 3D tensors get a batch dim.
    The original shape is restored at the end. Empty tensors are returned as-is. Only
    ``"nearest"`` and ``"bilinear"`` interpolation are accepted. Integer images are resampled in
    float32 and rounded back by ``_apply_grid_transform``.
    """
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim

    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
        needs_unsquash = True
    elif ndim == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True
    else:
        needs_unsquash = False

    height, width = shape[-2:]

    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [float(t) for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    return output
706
707


708
@_register_kernel_internal(affine, PIL.Image.Image)
def affine_image_pil(
    image: PIL.Image.Image,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    """Apply an affine transform to a PIL image via ``_FP.affine``.

    The inverse affine matrix is built around ``center``, which defaults to the image midpoint,
    and ``interpolation`` is mapped to a Pillow filter through ``pil_modes_mapping``.
    """
    interpolation = _check_interpolation(interpolation)
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
    # it is visually better to estimate the center without 0.5 offset
    # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
    if center is None:
        height, width = get_size_image_pil(image)
        center = [width * 0.5, height * 0.5]
    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

    return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)
731
732


733
734
def _expanded_affine_canvas_size(
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
) -> Tuple[int, int]:
    """Return the ``(height, width)`` canvas used when an affine transform is applied with ``expand=True``.

    The inverse affine matrix about the origin (center ``[0, 0]``) is passed to ``_compute_affine_output_size``.
    The arguments are expected to have already gone through ``_affine_parse_args``.
    """
    height, width = canvas_size
    # Estimate meta-data for image with inverted=True and with center=[0,0]
    affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
    new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
    return new_height, new_width


def _affine_bounding_boxes_with_expand(
    bounding_boxes: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    expand: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Apply an affine transformation to bounding boxes, optionally expanding the canvas.

    Steps:
      1. Convert each box to XYXY.
      2. Map its four corners through the forward (non-inverted) affine matrix.
      3. Use the axis-aligned box that encloses the mapped corners as the output box.

    With ``expand=True``, the boxes are shifted so that the smallest x/y of the transformed canvas corners
    becomes the origin, and the canvas size is enlarged. The result is clamped to the (possibly new) canvas,
    converted back to ``format``, and cast back to the input dtype.

    Returns:
        ``(boxes, canvas_size)``, where ``canvas_size`` is the ``(height, width)`` of the output canvas.
        Empty inputs are returned unchanged. With ``expand=True`` their canvas is still enlarged, exactly as it
        would be for non-empty boxes.
    """
    if bounding_boxes.numel() == 0:
        if expand:
            # Without this, an empty input would report the original canvas_size while non-empty boxes (and the
            # expanded image) report the enlarged one for the very same transform.
            angle, translate, shear, center = _affine_parse_args(
                angle, translate, scale, shear, InterpolationMode.NEAREST, center
            )
            canvas_size = _expanded_affine_canvas_size(canvas_size, angle, translate, scale, shear)
        return bounding_boxes, canvas_size

    original_shape = bounding_boxes.shape
    original_dtype = bounding_boxes.dtype
    # Work on a float copy so that the in-place format conversion below never modifies the caller's tensor.
    bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
    dtype = bounding_boxes.dtype
    device = bounding_boxes.device
    bounding_boxes = (
        convert_format_bounding_boxes(
            bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
        )
    ).reshape(-1, 4)

    angle, translate, shear, center = _affine_parse_args(
        angle, translate, scale, shear, InterpolationMode.NEAREST, center
    )

    if center is None:
        height, width = canvas_size
        center = [width * 0.5, height * 0.5]

    # Forward (non-inverted) 2x3 affine matrix, transposed to 3x2 so that rows of homogeneous points can be
    # right-multiplied by it.
    affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
    transposed_affine_matrix = (
        torch.tensor(
            affine_vector,
            dtype=dtype,
            device=device,
        )
        .reshape(2, 3)
        .T
    )
    # 1) Turn every box into its 4 corners in the order top-left, top-right, bottom-right, bottom-left,
    # i.e. [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] in homogeneous coordinates.
    # The tensor of points has shape (N * 4, 3), where N is the number of boxes.
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
    # 2) Transform the points with the affine matrix -> shape (N * 4, 2)
    transformed_points = torch.matmul(points, transposed_affine_matrix)
    # 3) Reshape to [N boxes, 4 points, x/y coords] and take the axis-aligned box enclosing the 4 transformed
    # corners.
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

    if expand:
        # Corners of the original canvas in the order top-left, bottom-left, bottom-right, top-right.
        height, width = canvas_size
        points = torch.tensor(
            [
                [0.0, 0.0, 1.0],
                [0.0, float(height), 1.0],
                [float(width), float(height), 1.0],
                [float(width), 0.0, 1.0],
            ],
            dtype=dtype,
            device=device,
        )
        new_points = torch.matmul(points, transposed_affine_matrix)
        # Smallest x and smallest y over the transformed canvas corners.
        tr = torch.amin(new_points, dim=0, keepdim=True)
        # Translate the boxes so that this minimum becomes the origin of the expanded canvas.
        out_bboxes.sub_(tr.repeat((1, 2)))
        canvas_size = _expanded_affine_canvas_size(canvas_size, angle, translate, scale, shear)

    out_bboxes = clamp_bounding_boxes(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
    out_bboxes = convert_format_bounding_boxes(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)

    # Cast back to the caller's dtype (a float -> integer conversion truncates).
    out_bboxes = out_bboxes.to(original_dtype)
    return out_bboxes, canvas_size
820
821


822
823
def affine_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to bounding boxes while keeping the canvas size.

    This is a thin wrapper around ``_affine_bounding_boxes_with_expand`` with ``expand=False``. Only the
    transformed boxes are returned; the (unchanged) canvas size is dropped.
    """
    result = _affine_bounding_boxes_with_expand(
        bounding_boxes,
        format=format,
        canvas_size=canvas_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
        expand=False,
    )
    return result[0]
844
845


846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
@_register_kernel_internal(affine, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _affine_bounding_boxes_dispatch(
    inpt: datapoints.BoundingBoxes,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    **kwargs,
) -> datapoints.BoundingBoxes:
    """``affine`` kernel for ``datapoints.BoundingBoxes``.

    Unwraps the datapoint, transforms it with ``affine_bounding_boxes`` using the datapoint's own format and
    canvas size, then re-wraps the result. Any extra keyword arguments are accepted and ignored.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    transformed = affine_bounding_boxes(
        plain_boxes,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
    )
    return datapoints.BoundingBoxes.wrap_like(inpt, transformed)


869
870
def affine_mask(
    mask: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to a mask with nearest-neighbour interpolation.

    A mask with fewer than 3 dimensions temporarily gets a leading dimension so that it can go through
    ``affine_image_tensor``. That dimension is removed again before returning.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = affine_image_tensor(
        mask,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    return output.squeeze(0) if needs_squeeze else output

900

901
902
903
904
905
906
907
@_register_kernel_internal(affine, datapoints.Mask, datapoint_wrapper=False)
def _affine_mask_dispatch(
    inpt: datapoints.Mask,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
    **kwargs,
) -> datapoints.Mask:
    """``affine`` kernel for ``datapoints.Mask``: unwrap, transform with ``affine_mask``, re-wrap.

    Any extra keyword arguments are accepted and ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    transformed = affine_mask(
        plain_mask,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        fill=fill,
        center=center,
    )
    return datapoints.Mask.wrap_like(inpt, transformed)


@_register_kernel_internal(affine, datapoints.Video)
def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """``affine`` kernel for videos: the video tensor is handed unchanged to ``affine_image_tensor``."""
    return affine_image_tensor(
        video,
        center=center,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
    )


947
def rotate(
    inpt: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate ``inpt`` using the kernel registered for ``type(inpt)``.

    Under TorchScript, the call always goes to ``rotate_image_tensor``, presumably because the type-based
    registry lookup below cannot be scripted.
    """
    if torch.jit.is_scripting():
        return rotate_image_tensor(
            inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center
        )

    _log_api_usage_once(rotate)

    rotate_kernel = _get_kernel(rotate, type(inpt))
    return rotate_kernel(
        inpt,
        angle=angle,
        interpolation=interpolation,
        expand=expand,
        fill=fill,
        center=center,
    )


@_register_kernel_internal(rotate, torch.Tensor)
@_register_kernel_internal(rotate, datapoints.Image)
def rotate_image_tensor(
    image: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate an image tensor whose last three dimensions are ``(C, H, W)``.

    - Only ``nearest`` and ``bilinear`` interpolation are accepted by the grid-transform checks.
    - With ``expand=True``, the output size comes from ``_compute_affine_output_size``. Passing ``center`` in
      that case emits a warning that it has no effect.
    - Empty tensors are not resampled, only reshaped to the output spatial size.
    """
    interpolation = _check_interpolation(interpolation)

    shape = image.shape
    num_channels, height, width = shape[-3:]

    center_f = [0.0, 0.0]
    if center is not None:
        if expand:
            # TODO: Do we actually want to warn, or just document this?
            warnings.warn("The provided center argument has no effect on the result if expand is True")
        # `center` is in pixel coordinates; shift it so that (0, 0) corresponds to the image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    # The rotation-angle direction of the affine and rotate implementations is currently not coherent,
    # so the angle has to be negated here.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

    if image.numel() == 0:
        # Nothing to resample: only the spatial size may change.
        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
        return image.reshape(shape[:-3] + (num_channels, new_height, new_width))

    batched = image.reshape(-1, num_channels, height, width)

    _assert_grid_transform_inputs(batched, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
    dtype = batched.dtype if torch.is_floating_point(batched) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=batched.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
    output = _apply_grid_transform(batched, grid, interpolation.value, fill=fill)

    new_height, new_width = output.shape[-2:]
    return output.reshape(shape[:-3] + (num_channels, new_height, new_width))
1010
1011


1012
@_register_kernel_internal(rotate, PIL.Image.Image)
def rotate_image_pil(
    image: PIL.Image.Image,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    """Rotate a PIL image via ``_FP.rotate`` using the PIL filter that matches ``interpolation``.

    A warning is emitted when ``center`` is given together with ``expand=True``.
    """
    interpolation = _check_interpolation(interpolation)

    if expand and center is not None:
        warnings.warn("The provided center argument has no effect on the result if expand is True")

    pil_interpolation = pil_modes_mapping[interpolation]
    return _FP.rotate(image, angle, interpolation=pil_interpolation, expand=expand, fill=fill, center=center)


1031
1032
def rotate_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Rotate bounding boxes and return ``(boxes, canvas_size)``.

    Implemented as a pure rotation (no translation, unit scale, no shear) through
    ``_affine_bounding_boxes_with_expand``. The angle is negated, mirroring ``rotate_image_tensor``. A warning is
    emitted when ``center`` is given together with ``expand=True``.
    """
    if expand and center is not None:
        warnings.warn("The provided center argument has no effect on the result if expand is True")

    return _affine_bounding_boxes_with_expand(
        bounding_boxes,
        format=format,
        canvas_size=canvas_size,
        angle=-angle,
        translate=[0.0, 0.0],
        scale=1.0,
        shear=[0.0, 0.0],
        center=center,
        expand=expand,
    )
1053
1054


1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
@_register_kernel_internal(rotate, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _rotate_bounding_boxes_dispatch(
    inpt: datapoints.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
) -> datapoints.BoundingBoxes:
    """``rotate`` kernel for ``datapoints.BoundingBoxes``.

    The re-wrapped result carries the canvas size returned by ``rotate_bounding_boxes``, which may differ from
    the input's when ``expand=True``. Extra keyword arguments are ignored.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    rotated, new_canvas_size = rotate_bounding_boxes(
        plain_boxes,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        angle=angle,
        expand=expand,
        center=center,
    )
    return datapoints.BoundingBoxes.wrap_like(inpt, rotated, canvas_size=new_canvas_size)


1070
1071
def rotate_mask(
    mask: torch.Tensor,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate a mask with nearest-neighbour interpolation.

    A mask with fewer than 3 dimensions temporarily gets a leading dimension for ``rotate_image_tensor``.
    That dimension is removed again before returning.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = rotate_image_tensor(
        mask,
        angle=angle,
        expand=expand,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    return output.squeeze(0) if needs_squeeze else output

1097

1098
1099
1100
1101
1102
1103
@_register_kernel_internal(rotate, datapoints.Mask, datapoint_wrapper=False)
def _rotate_mask_dispatch(
    inpt: datapoints.Mask,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
    **kwargs,
) -> datapoints.Mask:
    """``rotate`` kernel for ``datapoints.Mask``: unwrap, rotate with ``rotate_mask``, re-wrap.

    Extra keyword arguments are ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    rotated = rotate_mask(
        plain_mask,
        angle=angle,
        expand=expand,
        fill=fill,
        center=center,
    )
    return datapoints.Mask.wrap_like(inpt, rotated)


@_register_kernel_internal(rotate, datapoints.Video)
def rotate_video(
    video: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """``rotate`` kernel for videos: the video tensor is handed unchanged to ``rotate_image_tensor``."""
    return rotate_image_tensor(
        video,
        angle=angle,
        interpolation=interpolation,
        expand=expand,
        fill=fill,
        center=center,
    )


1123
def pad(
    inpt: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad ``inpt`` using the kernel registered for ``type(inpt)``.

    Under TorchScript, the call always goes to ``pad_image_tensor``.
    """
    if torch.jit.is_scripting():
        return pad_image_tensor(inpt, padding=padding, fill=fill, padding_mode=padding_mode)

    _log_api_usage_once(pad)

    pad_kernel = _get_kernel(pad, type(inpt))
    return pad_kernel(
        inpt,
        padding=padding,
        fill=fill,
        padding_mode=padding_mode,
    )
1136
1137


1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, (tuple, list)):
        if len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        elif len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        else:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )
    else:
        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")

    return [pad_left, pad_right, pad_top, pad_bottom]
1160

1161

1162
@_register_kernel_internal(pad, torch.Tensor)
@_register_kernel_internal(pad, datapoints.Image)
def pad_image_tensor(
    image: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad the last two dimensions of an image tensor.

    Args:
        padding: an int or a 1-, 2- or 4-element sequence in PIL order. A 4-element sequence is
            ``[left, top, right, bottom]``.
        fill: the fill value; ``None`` means 0. A one-element list is treated as a scalar. Longer lists are
            applied per channel and are only supported with ``padding_mode="constant"``.
        padding_mode: one of ``"constant"``, ``"edge"``, ``"reflect"`` or ``"symmetric"``.

    Raises:
        ValueError: for an unsupported ``padding_mode``, or for malformed ``padding``.
    """
    # The public `padding` follows PIL's order (`[left, top, right, bottom]` for four values), whereas
    # `torch_pad`, which is used internally, expects `[left, right, top, bottom]`. `_parse_pad_padding` does the
    # reordering.
    torch_padding = _parse_pad_padding(padding)

    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
        raise ValueError(
            f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
            f"but got `'{padding_mode}'`."
        )

    if fill is None:
        fill = 0

    if isinstance(fill, (int, float)):
        return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
    elif len(fill) == 1:
        # A single-element sequence behaves exactly like the scalar it contains.
        return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
    else:
        return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
1190
1191
1192


def _pad_with_scalar_fill(
1193
    image: torch.Tensor,
1194
1195
1196
    torch_padding: List[int],
    fill: Union[int, float],
    padding_mode: str,
1197
) -> torch.Tensor:
1198
1199
    shape = image.shape
    num_channels, height, width = shape[-3:]
1200

1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
    batch_size = 1
    for s in shape[:-3]:
        batch_size *= s

    image = image.reshape(batch_size, num_channels, height, width)

    if padding_mode == "edge":
        # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
        # name.
        padding_mode = "replicate"

    if padding_mode == "constant":
        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
    elif padding_mode in ("reflect", "replicate"):
        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
        # TODO: See https://github.com/pytorch/pytorch/issues/40763
        dtype = image.dtype
        if not image.is_floating_point():
            needs_cast = True
            image = image.to(torch.float32)
        else:
            needs_cast = False
1224

1225
1226
1227
1228
1229
        image = torch_pad(image, torch_padding, mode=padding_mode)

        if needs_cast:
            image = image.to(dtype)
    else:  # padding_mode == "symmetric"
1230
        image = _pad_symmetric(image, torch_padding)
1231
1232

    new_height, new_width = image.shape[-2:]
1233

1234
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
1235
1236


1237
# TODO: This should be removed once torch_pad supports non-scalar padding values
def _pad_with_vector_fill(
    image: torch.Tensor,
    torch_padding: List[int],
    fill: List[float],
    padding_mode: str,
) -> torch.Tensor:
    """Constant-pad ``image`` with a per-channel fill value.

    The image is first zero-padded. The added borders are then overwritten with ``fill``, reshaped to
    ``(-1, 1, 1)`` so that it broadcasts along the channel dimension. ``torch_padding`` is
    ``[left, right, top, bottom]``.

    Raises:
        ValueError: when ``padding_mode`` is not ``"constant"``.
    """
    if padding_mode != "constant":
        raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

    padded = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    pad_left, pad_right, pad_top, pad_bottom = torch_padding
    fill_values = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

    # Only strictly positive amounts add a border. The guards also matter because slices such as `-0:` would
    # select the entire axis.
    if pad_top > 0:
        padded[..., :pad_top, :] = fill_values
    if pad_left > 0:
        padded[..., :, :pad_left] = fill_values
    if pad_bottom > 0:
        padded[..., -pad_bottom:, :] = fill_values
    if pad_right > 0:
        padded[..., :, -pad_right:] = fill_values
    return padded


1262
# PIL kernel for `pad`: `_FP.pad` is registered for `PIL.Image.Image` inputs. `pad_image_pil` is bound to whatever
# `_register_kernel_internal(...)` returns for it (presumably `_FP.pad` itself -- confirm in `._utils`).
pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)
1263
1264


1265
@_register_kernel_internal(pad, datapoints.Mask)
def pad_mask(
    mask: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad a mask with ``pad_image_tensor``.

    Only scalar fill values are accepted; ``None`` means 0. A mask with fewer than 3 dimensions temporarily
    gets a leading dimension, which is removed again before returning.

    Raises:
        ValueError: when ``fill`` is a tuple or list.
    """
    if fill is None:
        fill = 0

    if isinstance(fill, (tuple, list)):
        raise ValueError("Non-scalar fill value is not supported")

    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = pad_image_tensor(mask, padding=padding, fill=fill, padding_mode=padding_mode)

    return output.squeeze(0) if needs_squeeze else output
1290
1291


1292
1293
def pad_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    padding: List[int],
    padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Shift bounding boxes to account for padding and grow the canvas accordingly.

    Returns:
        ``(boxes, canvas_size)``. The boxes are clamped to the enlarged ``(height, width)`` canvas.

    Raises:
        ValueError: for any ``padding_mode`` other than ``"constant"``.
    """
    if padding_mode != "constant":
        # TODO: add support of other padding modes
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

    left, right, top, bottom = _parse_pad_padding(padding)

    # The first two coordinates are always shifted. Only XYXY also stores a second corner that must move;
    # for the other formats the last two entries are left untouched.
    if format == datapoints.BoundingBoxFormat.XYXY:
        offset = [left, top, left, top]
    else:
        offset = [left, top, 0, 0]
    shifted = bounding_boxes + torch.tensor(offset, dtype=bounding_boxes.dtype, device=bounding_boxes.device)

    old_height, old_width = canvas_size
    new_canvas_size = (old_height + top + bottom, old_width + left + right)

    return clamp_bounding_boxes(shifted, format=format, canvas_size=new_canvas_size), new_canvas_size
1317
1318


1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
@_register_kernel_internal(pad, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _pad_bounding_boxes_dispatch(
    inpt: datapoints.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
) -> datapoints.BoundingBoxes:
    """``pad`` kernel for ``datapoints.BoundingBoxes``.

    The re-wrapped result carries the enlarged canvas size. Extra keyword arguments (such as ``fill``) are
    ignored.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    padded, new_canvas_size = pad_bounding_boxes(
        plain_boxes,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        padding=padding,
        padding_mode=padding_mode,
    )
    return datapoints.BoundingBoxes.wrap_like(inpt, padded, canvas_size=new_canvas_size)


@_register_kernel_internal(pad, datapoints.Video)
def pad_video(
    video: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """``pad`` kernel for videos: the video tensor is handed unchanged to ``pad_image_tensor``."""
    return pad_image_tensor(
        video,
        padding=padding,
        fill=fill,
        padding_mode=padding_mode,
    )


1343
def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop ``inpt`` using the kernel registered for ``type(inpt)``.

    Under TorchScript, the call always goes to ``crop_image_tensor``.
    """
    if torch.jit.is_scripting():
        return crop_image_tensor(inpt, top=top, left=left, height=height, width=width)

    _log_api_usage_once(crop)

    crop_kernel = _get_kernel(crop, type(inpt))
    return crop_kernel(
        inpt,
        top=top,
        left=left,
        height=height,
        width=width,
    )
1351

1352
1353

@_register_kernel_internal(crop, torch.Tensor)
@_register_kernel_internal(crop, datapoints.Image)
def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop the last two dimensions of ``image`` to a ``height x width`` window starting at ``(top, left)``.

    The window may extend past the image on any side, or lie completely outside it. Parts of the window not
    covered by the image are zero-filled, so the result always has spatial size ``(height, width)``.
    """
    h, w = image.shape[-2:]

    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        # Keep only the part of the window that overlaps the image. The slice ends are clamped at 0 as well as
        # the starts: if the window lies entirely left of (right < 0) or above (bottom < 0) the image, a negative
        # end would index from the far side of the axis and keep pixels outside the requested window.
        image = image[..., max(top, 0) : max(bottom, 0), max(left, 0) : max(right, 0)]
        # Zero-fill the uncovered part of the window; `torch_pad` order is [left, right, top, bottom].
        torch_padding = [
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    return image[..., top:bottom, left:right]


1373
# The PIL crop kernel is `_FP.crop` used as-is; it is registered with the `crop` dispatcher further below.
crop_image_pil = _FP.crop
1374
# Register `crop_image_pil` as the `crop` kernel for `PIL.Image.Image` inputs (the return value is not kept).
_register_kernel_internal(crop, PIL.Image.Image)(crop_image_pil)
1375
1376


1377
1378
def crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Move bounding boxes into the coordinate frame of a crop window and clamp them to it.

    Negative ``top``/``left`` act as an implicit pad.

    Returns:
        ``(boxes, canvas_size)`` with ``canvas_size == (height, width)``.
    """
    # Crop, or implicit pad if left and/or top are negative: subtract the window origin from the first two
    # coordinates, and for XYXY from the second corner as well.
    if format == datapoints.BoundingBoxFormat.XYXY:
        offset = [left, top, left, top]
    else:
        offset = [left, top, 0, 0]
    shifted = bounding_boxes - torch.tensor(offset, dtype=bounding_boxes.dtype, device=bounding_boxes.device)

    new_canvas_size = (height, width)
    return clamp_bounding_boxes(shifted, format=format, canvas_size=new_canvas_size), new_canvas_size
1396
1397


1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
@_register_kernel_internal(crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _crop_bounding_boxes_dispatch(
    inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int
) -> datapoints.BoundingBoxes:
    """``crop`` kernel for ``datapoints.BoundingBoxes``; the result carries the crop window's canvas size."""
    plain_boxes = inpt.as_subclass(torch.Tensor)
    cropped, new_canvas_size = crop_bounding_boxes(
        plain_boxes,
        format=inpt.format,
        top=top,
        left=left,
        height=height,
        width=width,
    )
    return datapoints.BoundingBoxes.wrap_like(inpt, cropped, canvas_size=new_canvas_size)


@_register_kernel_internal(crop, datapoints.Mask)
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a mask with ``crop_image_tensor``.

    A mask with fewer than 3 dimensions temporarily gets a leading dimension, which is removed again before
    returning.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = crop_image_tensor(mask, top=top, left=left, height=height, width=width)

    return output.squeeze(0) if needs_squeeze else output
1422
1423


1424
@_register_kernel_internal(crop, datapoints.Video)
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """``crop`` kernel for videos: the video tensor is handed unchanged to ``crop_image_tensor``."""
    return crop_image_tensor(video, top=top, left=left, height=height, width=width)


1429
def perspective(
    inpt: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transform to ``inpt`` using the kernel registered for ``type(inpt)``.

    The transform is given either by ``startpoints``/``endpoints`` or by 8 ``coefficients``. Under
    TorchScript, the call always goes to ``perspective_image_tensor``.
    """
    if torch.jit.is_scripting():
        return perspective_image_tensor(
            inpt,
            startpoints=startpoints,
            endpoints=endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )

    _log_api_usage_once(perspective)

    perspective_kernel = _get_kernel(perspective, type(inpt))
    return perspective_kernel(
        inpt,
        startpoints=startpoints,
        endpoints=endpoints,
        interpolation=interpolation,
        fill=fill,
        coefficients=coefficients,
    )

1459

1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394

    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
1475
    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
1476
    base_grid[..., 0].copy_(x_grid)
1477
    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
1478
1479
1480
1481
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
1482
1483
1484
    shape = (1, oh * ow, 3)
    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))
1485
1486
1487
1488
1489

    output_grid = output_grid1.div_(output_grid2).sub_(1.0)
    return output_grid.view(1, oh, ow, 2)


1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
def _perspective_coefficients(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]],
) -> List[float]:
    if coefficients is not None:
        if startpoints is not None and endpoints is not None:
            raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.")
        elif len(coefficients) != 8:
            raise ValueError("Argument coefficients should have 8 float values")
        return coefficients
    elif startpoints is not None and endpoints is not None:
        return _get_perspective_coeffs(startpoints, endpoints)
    else:
        raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")


1507
@_register_kernel_internal(perspective, torch.Tensor)
@_register_kernel_internal(perspective, datapoints.Image)
def perspective_image_tensor(
    image: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transform to an image tensor, keeping its spatial size.

    - Only ``nearest`` and ``bilinear`` interpolation are accepted by the grid-transform checks.
    - Empty tensors are returned as-is.
    - 3D inputs and inputs with more than 4 dimensions are brought to a 4D batch for resampling, then
      reshaped back to the input shape.
    """
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim

    # Bring the input to 4D: add a batch dim to a single image, or fold all extra leading dims into one.
    if ndim == 3:
        image = image.unsqueeze(0)
    elif ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
    needs_unsquash = ndim == 3 or ndim > 4

    _assert_grid_transform_inputs(
        image,
        matrix=None,
        interpolation=interpolation.value,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    # The output grid has the same spatial size as the input.
    oh, ow = shape[-2:]
    grid_dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=grid_dtype, device=image.device)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    return output.reshape(shape) if needs_unsquash else output
1553
1554


1555
@_register_kernel_internal(perspective, PIL.Image.Image)
def perspective_image_pil(
    image: PIL.Image.Image,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
    """Apply a perspective warp to a PIL image.

    Coefficients are resolved by ``_perspective_coefficients``. The interpolation mode is
    validated and translated to Pillow's resampling constant via ``pil_modes_mapping``.
    """
    coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    pil_resample = pil_modes_mapping[_check_interpolation(interpolation)]
    return _FP.perspective(image, coeffs, interpolation=pil_resample, fill=fill)
1567
1568


1569
1570
def perspective_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Warp bounding boxes with the same perspective transform applied to images.

    The perspective coefficients describe the end point -> start point mapping used for
    image sampling, so the boxes are moved with their *inverse*. Each box's four corners
    are warped, and the axis-aligned hull of the warped corners is clamped to
    ``canvas_size`` and converted back to ``format``.

    Args:
        bounding_boxes: tensor whose trailing values are groups of 4 box coordinates in ``format``.
        format: coordinate format of ``bounding_boxes``.
        canvas_size: ``(height, width)`` of the canvas, used for clamping.
        startpoints, endpoints, coefficients: forwarded to ``_perspective_coefficients``.

    Returns:
        The warped boxes, reshaped to the input shape. Empty inputs are returned as-is.

    Raises:
        RuntimeError: if the coefficients cannot be inverted (zero denominator).
    """
    if bounding_boxes.numel() == 0:
        return bounding_boxes

    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_format_bounding_boxes
    bounding_boxes = (
        convert_format_bounding_boxes(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    # All geometry below is computed in one floating dtype: the boxes' own dtype if it is
    # floating, float32 otherwise.
    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32
    device = bounding_boxes.device

    # perspective_coeffs are computed as endpoint -> start point
    # We have to invert perspective_coeffs for bboxes:
    # (x, y) - end point and (x_out, y_out) - start point
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # and we would like to get:
    # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # and compute inv_coeffs in terms of coeffs

    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
    if denom == 0:
        raise RuntimeError(
            f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
            f"Denominator is zero, denom={denom}"
        )

    inv_coeffs = [
        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
        (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
        (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
        (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
    ]

    # Numerator rows for x and y.
    theta1 = torch.tensor(
        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
        dtype=dtype,
        device=device,
    )

    # x and y share the same projective denominator, hence the duplicated row.
    theta2 = torch.tensor(
        [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
    )

    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    # Both the corner coordinates and the homogeneous column are put in `dtype` explicitly so the
    # matmuls below always see operands of the same dtype as theta1/theta2. Previously the ones
    # column defaulted to float32, which broke float16/bfloat16 boxes.
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2).to(dtype)
    points = torch.cat([points, torch.ones(points.shape[0], 1, dtype=dtype, device=device)], dim=-1)
    # 2) Now let's transform the points using perspective matrices
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)

    numer_points = torch.matmul(points, theta1.T)
    denom_points = torch.matmul(points, theta2.T)
    transformed_points = numer_points.div_(denom_points)

    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)

    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    # out_bboxes should be of shape [N boxes, 4]

    return convert_format_bounding_boxes(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
1661
1662


1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
@_register_kernel_internal(perspective, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _perspective_bounding_boxes_dispatch(
    inpt: datapoints.BoundingBoxes,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> datapoints.BoundingBoxes:
    """Unwrap a ``BoundingBoxes`` datapoint, warp it, and wrap the result like the input.

    Extra keyword arguments are accepted and ignored.
    """
    raw_boxes = inpt.as_subclass(torch.Tensor)
    warped = perspective_bounding_boxes(
        raw_boxes,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        startpoints=startpoints,
        endpoints=endpoints,
        coefficients=coefficients,
    )
    return datapoints.BoundingBoxes.wrap_like(inpt, warped)


1682
1683
def perspective_mask(
    mask: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Warp a mask with nearest-neighbour sampling through ``perspective_image_tensor``.

    Masks with fewer than 3 dims get a temporary leading dim, which is removed again.
    """
    add_leading_dim = mask.ndim < 3
    source = mask.unsqueeze(0) if add_leading_dim else mask
    warped = perspective_image_tensor(
        source, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients
    )
    return warped.squeeze(0) if add_leading_dim else warped

1704

1705
1706
1707
1708
1709
@_register_kernel_internal(perspective, datapoints.Mask, datapoint_wrapper=False)
def _perspective_mask_dispatch(
    inpt: datapoints.Mask,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> datapoints.Mask:
    """Unwrap a ``Mask`` datapoint, warp it with ``perspective_mask``, and wrap it like the input.

    Extra keyword arguments are accepted and ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    warped = perspective_mask(
        plain_mask, startpoints=startpoints, endpoints=endpoints, fill=fill, coefficients=coefficients
    )
    return datapoints.Mask.wrap_like(inpt, warped)


@_register_kernel_internal(perspective, datapoints.Video)
def perspective_video(
    video: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Warp a video tensor frame-wise by delegating to ``perspective_image_tensor``.

    The image kernel folds dims beyond the last three into one batch dim.
    """
    return perspective_image_tensor(
        video,
        startpoints=startpoints,
        endpoints=endpoints,
        interpolation=interpolation,
        fill=fill,
        coefficients=coefficients,
    )
1736
1737


1738
def elastic(
    inpt: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Apply an elastic (displacement-field) transform, dispatching on the input type.

    Under TorchScript the plain image-tensor kernel is always used. In eager mode API usage
    is logged and the kernel registered for ``type(inpt)`` is called.
    """
    if torch.jit.is_scripting():
        return elastic_image_tensor(inpt, displacement=displacement, interpolation=interpolation, fill=fill)

    _log_api_usage_once(elastic)
    return _get_kernel(elastic, type(inpt))(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
1751
1752


1753
1754
1755
# Second public name bound to the very same function object as `elastic`.
elastic_transform = elastic


1756
@_register_kernel_internal(elastic, torch.Tensor)
@_register_kernel_internal(elastic, datapoints.Image)
def elastic_image_tensor(
    image: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Resample an image tensor on an identity grid shifted by ``displacement``.

    ``displacement`` must have shape ``(1, H, W, 2)``, where ``H, W`` are the image's last
    two dims. It is moved to the image's device and to the working dtype (the image dtype if
    floating, float32 otherwise). CPU float16 images are processed in float32 and cast back.
    3-D inputs get a temporary batch dim and inputs with more than 4 dims have their leading
    dims folded into one; the output is reshaped to the input shape. Empty tensors are
    returned as-is.

    Raises:
        ValueError: if ``displacement`` does not have the expected shape.
    """
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    input_shape = image.shape
    rank = image.ndim
    device = image.device
    work_dtype = image.dtype if torch.is_floating_point(image) else torch.float32

    # Patch: elastic transform should support (cpu,f16) input
    upcast_cpu_half = device.type == "cpu" and work_dtype == torch.float16
    if upcast_cpu_half:
        image = image.to(torch.float32)
        work_dtype = torch.float32

    # We are aware that if input image dtype is uint8 and displacement is float64 then
    # displacement will be casted to float32 and all computations will be done with float32
    # We can fix this later if needed

    expected_shape = (1,) + input_shape[-2:] + (2,)
    if expected_shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    # The grid helpers work on a 4-D batch: add or fold dims now, undo at the end.
    restore_shape = rank == 3 or rank > 4
    if rank == 3:
        image = image.unsqueeze(0)
    elif rank > 4:
        image = image.reshape((-1,) + input_shape[-3:])

    if displacement.dtype != work_dtype or displacement.device != device:
        displacement = displacement.to(dtype=work_dtype, device=device)

    grid_height, grid_width = input_shape[-2:]
    sampling_grid = _create_identity_grid((grid_height, grid_width), device=device, dtype=work_dtype).add_(
        displacement
    )
    result = _apply_grid_transform(image, sampling_grid, interpolation.value, fill=fill)

    if restore_shape:
        result = result.reshape(input_shape)
    if upcast_cpu_half:
        result = result.to(torch.float16)
    return result
1812
1813


1814
@_register_kernel_internal(elastic, PIL.Image.Image)
def elastic_image_pil(
    image: PIL.Image.Image,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    """Elastic transform for PIL images, done via a round trip through the tensor kernel.

    The image is converted to a tensor, warped by ``elastic_image_tensor``, and converted
    back to PIL using the original image mode.
    """
    warped = elastic_image_tensor(pil_to_tensor(image), displacement, interpolation=interpolation, fill=fill)
    return to_pil_image(warped, mode=image.mode)
1824
1825


1826
def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
1827
    sy, sx = size
1828
1829
    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
1830
1831
    base_grid[..., 0].copy_(x_grid)

1832
    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
1833
1834
1835
1836
1837
    base_grid[..., 1].copy_(y_grid)

    return base_grid


1838
1839
def elastic_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    displacement: torch.Tensor,
) -> torch.Tensor:
    """Move bounding boxes according to an elastic displacement field.

    The inverse of the image sampling grid is *approximated* as
    ``identity_grid - displacement``; this is not an exact inverse. Each box corner is
    rounded up to an integer pixel index, clamped into the canvas, and looked up in that
    approximate inverse grid, then mapped back to pixel coordinates. The axis-aligned hull
    of the four corners is clamped to ``canvas_size`` and converted back to ``format``.

    Args:
        bounding_boxes: tensor whose trailing values are groups of 4 box coordinates in ``format``.
        format: coordinate format of ``bounding_boxes``.
        canvas_size: ``(height, width)`` of the canvas; the identity grid is built with this size.
        displacement: displacement field that must broadcast against the ``(1, height, width, 2)``
            identity grid (presumably the same field passed to the image kernel — confirm at call site).

    Returns:
        The moved boxes, reshaped to the input shape. Empty inputs are returned as-is.
    """
    if bounding_boxes.numel() == 0:
        return bounding_boxes

    device = bounding_boxes.device
    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_format_bounding_boxes
    bounding_boxes = (
        convert_format_bounding_boxes(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
    # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
    # This is not an exact inverse of the grid
    inv_grid = id_grid.sub_(displacement)

    # Get the 4 corners of every box: (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    if points.is_floating_point():
        points = points.ceil_()
    index_xy = points.to(dtype=torch.long)
    # The grid is (1, height, width, 2). After `ceil`, any corner on the right/bottom border
    # (x in (width - 1, width], y in (height - 1, height]) lands one index past the grid, which
    # raised an IndexError. Corners left of / above the canvas would wrap around silently via
    # negative indexing. Clamp the lookup indices into the grid.
    canvas_height, canvas_width = canvas_size
    index_x = index_xy[:, 0].clamp(0, canvas_width - 1)
    index_y = index_xy[:, 1].clamp(0, canvas_height - 1)

    # Transform points: map normalised grid coordinates back to pixel coordinates
    t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    return convert_format_bounding_boxes(
        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
1887
1888


1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
@_register_kernel_internal(elastic, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _elastic_bounding_boxes_dispatch(
    inpt: datapoints.BoundingBoxes, displacement: torch.Tensor, **kwargs
) -> datapoints.BoundingBoxes:
    """Unwrap a ``BoundingBoxes`` datapoint, apply ``elastic_bounding_boxes``, and wrap it like the input.

    Extra keyword arguments are accepted and ignored.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    moved = elastic_bounding_boxes(
        plain_boxes, format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
    )
    return datapoints.BoundingBoxes.wrap_like(inpt, moved)


1899
1900
1901
def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Elastic-transform a mask with nearest-neighbour sampling via ``elastic_image_tensor``.

    Masks with fewer than 3 dims get a temporary leading dim, which is removed again.
    """
    add_leading_dim = mask.ndim < 3
    source = mask.unsqueeze(0) if add_leading_dim else mask
    warped = elastic_image_tensor(
        source, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill
    )
    return warped.squeeze(0) if add_leading_dim else warped
1916
1917


1918
1919
@_register_kernel_internal(elastic, datapoints.Mask, datapoint_wrapper=False)
def _elastic_mask_dispatch(
    inpt: datapoints.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
) -> datapoints.Mask:
    """Unwrap a ``Mask`` datapoint, apply ``elastic_mask``, and wrap it like the input.

    Extra keyword arguments are accepted and ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    warped = elastic_mask(plain_mask, displacement=displacement, fill=fill)
    return datapoints.Mask.wrap_like(inpt, warped)


@_register_kernel_internal(elastic, datapoints.Video)
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Elastic-transform a video by delegating to ``elastic_image_tensor``.

    The image kernel folds dims beyond the last three into one batch dim.
    """
    return elastic_image_tensor(video, displacement=displacement, interpolation=interpolation, fill=fill)
1934
1935


1936
def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Center-crop ``inpt`` to ``output_size``, dispatching on the input type.

    Under TorchScript the plain image-tensor kernel is always used. In eager mode API usage
    is logged and the kernel registered for ``type(inpt)`` is called.
    """
    if torch.jit.is_scripting():
        return center_crop_image_tensor(inpt, output_size=output_size)

    _log_api_usage_once(center_crop)
    return _get_kernel(center_crop, type(inpt))(inpt, output_size=output_size)
1944
1945


1946
1947
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    if isinstance(output_size, numbers.Number):
1948
1949
        s = int(output_size)
        return [s, s]
1950
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
1951
        return [output_size[0], output_size[0]]
1952
1953
    else:
        return list(output_size)
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972


def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
    return [
        (crop_width - image_width) // 2 if crop_width > image_width else 0,
        (crop_height - image_height) // 2 if crop_height > image_height else 0,
        (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
        (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
    ]


def _center_crop_compute_crop_anchor(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop_top, crop_left


1973
@_register_kernel_internal(center_crop, torch.Tensor)
@_register_kernel_internal(center_crop, datapoints.Image)
def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Crop the central ``output_size`` region of an image tensor (last two dims).

    ``output_size`` is normalised by ``_center_crop_parse_output_size``. If the crop is larger
    than the image along an axis, the image is first zero-padded around the centre (any odd
    extra pixel goes to the right/bottom) so that the crop fits.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)

    shape = image.shape
    if image.numel() == 0:
        # Nothing to sample from, so the result is pure zero padding. Previously this reshaped
        # the empty input to the crop size, which raises whenever the emptiness comes from a
        # zero-sized H/W dim and the requested crop is non-empty.
        return image.new_zeros(shape[:-2] + (crop_height, crop_width))
    image_height, image_width = shape[-2:]

    if crop_height > image_height or crop_width > image_width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)

        image_height, image_width = image.shape[-2:]
        # Padding may already have produced exactly the requested size.
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
1992
1993


1994
@_register_kernel_internal(center_crop, PIL.Image.Image)
def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
    """Crop the central ``output_size`` region of a PIL image, zero-padding it first if too small."""
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    image_height, image_width = get_size_image_pil(image)

    if crop_height > image_height or crop_width > image_width:
        image = pad_image_pil(
            image, _center_crop_compute_padding(crop_height, crop_width, image_height, image_width), fill=0
        )
        image_height, image_width = get_size_image_pil(image)
        # Padding may already have produced exactly the requested size.
        if (crop_height, crop_width) == (image_height, image_width):
            return image

    top, left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return crop_image_pil(image, top, left, crop_height, crop_width)
2009
2010


2011
2012
def center_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Adjust bounding boxes for a center crop of ``output_size`` on a ``canvas_size`` canvas.

    The crop window is computed exactly as for images and handed to ``crop_bounding_boxes``,
    whose ``(boxes, canvas_size)`` result is returned unchanged.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    anchor = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
    return crop_bounding_boxes(
        bounding_boxes, format, top=anchor[0], left=anchor[1], height=crop_height, width=crop_width
    )
2022
2023


2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
@_register_kernel_internal(center_crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _center_crop_bounding_boxes_dispatch(
    inpt: datapoints.BoundingBoxes, output_size: List[int]
) -> datapoints.BoundingBoxes:
    """Unwrap a ``BoundingBoxes`` datapoint, center-crop it, and wrap it with the new canvas size."""
    plain_boxes = inpt.as_subclass(torch.Tensor)
    cropped, new_canvas_size = center_crop_bounding_boxes(
        plain_boxes, format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
    )
    return datapoints.BoundingBoxes.wrap_like(inpt, cropped, canvas_size=new_canvas_size)


@_register_kernel_internal(center_crop, datapoints.Mask)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Center-crop a mask via ``center_crop_image_tensor``.

    Masks with fewer than 3 dims get a temporary leading dim, which is removed again.
    """
    add_leading_dim = mask.ndim < 3
    source = mask.unsqueeze(0) if add_leading_dim else mask
    cropped = center_crop_image_tensor(image=source, output_size=output_size)
    return cropped.squeeze(0) if add_leading_dim else cropped
2048
2049


2050
@_register_kernel_internal(center_crop, datapoints.Video)
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Center-crop a video; the image kernel crops the last two dims of every frame."""
    return center_crop_image_tensor(image=video, output_size=output_size)


2055
def resized_crop(
    inpt: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Crop the ``(top, left, height, width)`` window and resize it to ``size``, dispatching on type.

    Under TorchScript the plain image-tensor kernel is always used. In eager mode API usage
    is logged and the kernel registered for ``type(inpt)`` is called.
    """
    if torch.jit.is_scripting():
        return resized_crop_image_tensor(
            inpt,
            top=top,
            left=left,
            height=height,
            width=width,
            size=size,
            interpolation=interpolation,
            antialias=antialias,
        )

    _log_api_usage_once(resized_crop)

    type_kernel = _get_kernel(resized_crop, type(inpt))
    return type_kernel(
        inpt,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
        interpolation=interpolation,
        antialias=antialias,
    )
2090

2091
2092

@_register_kernel_internal(resized_crop, torch.Tensor)
@_register_kernel_internal(resized_crop, datapoints.Image)
def resized_crop_image_tensor(
    image: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Crop an image tensor with ``crop_image_tensor``, then resize it with ``resize_image_tensor``."""
    cropped = crop_image_tensor(image, top, left, height, width)
    return resize_image_tensor(cropped, size, interpolation=interpolation, antialias=antialias)
2106
2107
2108


def resized_crop_image_pil(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
    """Crop a PIL image with ``crop_image_pil``, then resize it with ``resize_image_pil``."""
    cropped = crop_image_pil(image, top, left, height, width)
    return resize_image_pil(cropped, size, interpolation=interpolation)
2119
2120


2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
@_register_kernel_internal(resized_crop, PIL.Image.Image)
def resized_crop_image_pil_dispatch(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> PIL.Image.Image:
    """PIL kernel for ``resized_crop``; ``antialias`` is not forwarded.

    An explicit ``antialias=False`` only emits a warning, since it is ignored for PIL input.
    """
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return resized_crop_image_pil(image, top, left, height, width, size, interpolation=interpolation)


2145
2146
def resized_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: datapoints.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Crop bounding boxes to a window, then rescale them to ``size``.

    Returns whatever ``resize_bounding_boxes`` returns, i.e. ``(boxes, canvas_size)`` per the annotation.
    """
    cropped_boxes, cropped_canvas = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
    return resize_bounding_boxes(cropped_boxes, canvas_size=cropped_canvas, size=size)


@_register_kernel_internal(resized_crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _resized_crop_bounding_boxes_dispatch(
    inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> datapoints.BoundingBoxes:
    """Unwrap a ``BoundingBoxes`` datapoint, crop+resize it, and wrap it with the new canvas size.

    Extra keyword arguments are accepted and ignored.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    resized, new_canvas_size = resized_crop_bounding_boxes(
        plain_boxes, format=inpt.format, top=top, left=left, height=height, width=width, size=size
    )
    return datapoints.BoundingBoxes.wrap_like(inpt, resized, canvas_size=new_canvas_size)
2166
2167


2168
def resized_crop_mask(
    mask: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> torch.Tensor:
    """Crop a mask with ``crop_mask`` and resize the result with ``resize_mask``."""
    return resize_mask(crop_mask(mask, top, left, height, width), size)
2178
2179


2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
@_register_kernel_internal(resized_crop, datapoints.Mask, datapoint_wrapper=False)
def _resized_crop_mask_dispatch(
    inpt: datapoints.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> datapoints.Mask:
    """Unwrap a ``Mask`` datapoint, crop+resize it, and wrap it like the input.

    Extra keyword arguments are accepted and ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    resized = resized_crop_mask(plain_mask, top=top, left=left, height=height, width=width, size=size)
    return datapoints.Mask.wrap_like(inpt, resized)


@_register_kernel_internal(resized_crop, datapoints.Video)
def resized_crop_video(
    video: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Crop and resize every frame of a video by delegating to ``resized_crop_image_tensor``."""
    return resized_crop_image_tensor(
        video, top, left, height, width, size, interpolation=interpolation, antialias=antialias
    )


2206
2207
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
def five_crop(
    inpt: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the four corner crops and the center crop of ``inpt``, dispatching on type.

    Bounding boxes and masks are registered through ``_register_explicit_noop`` with
    ``warn_passthrough=True``. Under TorchScript the plain image-tensor kernel is always used.
    """
    if torch.jit.is_scripting():
        return five_crop_image_tensor(inpt, size=size)

    _log_api_usage_once(five_crop)
    return _get_kernel(five_crop, type(inpt))(inpt, size=size)
2217
2218


2219
2220
def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
2221
2222
        s = int(size)
        size = [s, s]
2223
    elif isinstance(size, (tuple, list)) and len(size) == 1:
2224
2225
        s = size[0]
        size = [s, s]
2226
2227
2228
2229
2230
2231
2232

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size


2233
@_register_five_ten_crop_kernel(five_crop, torch.Tensor)
@_register_five_ten_crop_kernel(five_crop, datapoints.Image)
def five_crop_image_tensor(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return ``(top_left, top_right, bottom_left, bottom_right, center)`` crops of an image tensor.

    Raises:
        ValueError: if the crop is larger than the image along either axis.
    """
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = image.shape[-2:]

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    last_top = image_height - crop_height
    last_left = image_width - crop_width
    tl = crop_image_tensor(image, 0, 0, crop_height, crop_width)
    tr = crop_image_tensor(image, 0, last_left, crop_height, crop_width)
    bl = crop_image_tensor(image, last_top, 0, crop_height, crop_width)
    br = crop_image_tensor(image, last_top, last_left, crop_height, crop_width)
    center = center_crop_image_tensor(image, [crop_height, crop_width])
    return tl, tr, bl, br, center


2253
@_register_five_ten_crop_kernel(five_crop, PIL.Image.Image)
def five_crop_image_pil(
    image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    """Return ``(top_left, top_right, bottom_left, bottom_right, center)`` crops of a PIL image.

    Raises:
        ValueError: if the crop is larger than the image along either axis.
    """
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = get_size_image_pil(image)

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    last_top = image_height - crop_height
    last_left = image_width - crop_width
    tl = crop_image_pil(image, 0, 0, crop_height, crop_width)
    tr = crop_image_pil(image, 0, last_left, crop_height, crop_width)
    bl = crop_image_pil(image, last_top, 0, crop_height, crop_width)
    br = crop_image_pil(image, last_top, last_left, crop_height, crop_width)
    center = center_crop_image_pil(image, [crop_height, crop_width])
    return tl, tr, bl, br, center


2272
@_register_five_ten_crop_kernel(five_crop, datapoints.Video)
def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Five-crop a video by delegating to ``five_crop_image_tensor`` (crops act on the last two dims)."""
    return five_crop_image_tensor(video, size=size)


2279
2280
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
def ten_crop(
    inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Return the five crops of ``inpt`` followed by the five crops of its flipped version.

    Bounding boxes and masks are registered through ``_register_explicit_noop`` with
    ``warn_passthrough=True``. Under TorchScript the plain image-tensor kernel is always used.
    """
    if torch.jit.is_scripting():
        return ten_crop_image_tensor(inpt, size=size, vertical_flip=vertical_flip)

    _log_api_usage_once(ten_crop)
    return _get_kernel(ten_crop, type(inpt))(inpt, size=size, vertical_flip=vertical_flip)
2301
2302


2303
@_register_five_ten_crop_kernel(ten_crop, torch.Tensor)
@_register_five_ten_crop_kernel(ten_crop, datapoints.Image)
def ten_crop_image_tensor(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Ten-crop kernel for plain tensors and ``datapoints.Image``.

    The first five outputs are ``five_crop_image_tensor(image, size)``. The last five are
    the same five crops taken from a mirrored copy of ``image``: flipped vertically when
    ``vertical_flip`` is True, horizontally otherwise.
    """
    # Crop the original first, so that an invalid ``size`` is reported before any flip runs.
    unflipped_crops = five_crop_image_tensor(image, size)
    # Conditional expression instead of a function variable keeps this TorchScript-compatible.
    mirrored = vertical_flip_image_tensor(image) if vertical_flip else horizontal_flip_image_tensor(image)
    return unflipped_crops + five_crop_image_tensor(mirrored, size)
2329
2330


2331
@_register_five_ten_crop_kernel(ten_crop, PIL.Image.Image)
def ten_crop_image_pil(
    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
]:
    """Ten-crop kernel for PIL images.

    Returns the five crops of ``image`` (from ``five_crop_image_pil``) followed by the five
    crops of a mirrored copy: flipped vertically when ``vertical_flip`` is True,
    horizontally otherwise.
    """
    # Crop the original first, so that an invalid ``size`` is reported before any flip runs.
    unflipped_crops = five_crop_image_pil(image, size)
    mirror = vertical_flip_image_pil if vertical_flip else horizontal_flip_image_pil
    return unflipped_crops + five_crop_image_pil(mirror(image), size)


2358
@_register_five_ten_crop_kernel(ten_crop, datapoints.Video)
def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Ten-crop kernel for ``datapoints.Video``.

    Forwards directly to the plain tensor kernel ``ten_crop_image_tensor``, passing
    ``size`` and ``vertical_flip`` through unchanged.
    """
    return ten_crop_image_tensor(video, size=size, vertical_flip=vertical_flip)