_geometry.py 83.7 KB
Newer Older
1
import math
2
import numbers
3
import warnings
4
from typing import Any, List, Optional, Sequence, Tuple, Union
5
6
7

import PIL.Image
import torch
8
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
9

10
from torchvision import tv_tensors
11
12
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric
13
from torchvision.transforms.functional import (
14
    _compute_resized_output_size as __compute_resized_output_size,
15
    _get_perspective_coeffs,
16
    _interpolation_modes_from_int,
17
    InterpolationMode,
18
    pil_modes_mapping,
19
20
    pil_to_tensor,
    to_pil_image,
21
)
22

23
24
from torchvision.utils import _log_api_usage_once

Nicolas Hug's avatar
Nicolas Hug committed
25
from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format
26

27
from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
28

29

30
31
32
33
34
35
36
37
38
39
40
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    """Normalize ``interpolation`` into an :class:`InterpolationMode`.

    Args:
        interpolation: either an ``InterpolationMode`` (returned unchanged) or an integer Pillow resampling
            constant, which is translated with ``_interpolation_modes_from_int``.

    Returns:
        The corresponding ``InterpolationMode``.

    Raises:
        ValueError: if ``interpolation`` is neither an ``int`` nor an ``InterpolationMode``.
    """
    if isinstance(interpolation, InterpolationMode):
        mode = interpolation
    elif isinstance(interpolation, int):
        mode = _interpolation_modes_from_int(interpolation)
    else:
        raise ValueError(
            f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
            f"but got {interpolation}."
        )
    return mode


def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomHorizontalFlip` for details."""
    if torch.jit.is_scripting():
        # Under TorchScript the type-based kernel lookup is skipped: the input is handled as a plain image tensor.
        return horizontal_flip_image(inpt)

    _log_api_usage_once(horizontal_flip)

    flip_kernel = _get_kernel(horizontal_flip, type(inpt))
    return flip_kernel(inpt)


@_register_kernel_internal(horizontal_flip, torch.Tensor)
@_register_kernel_internal(horizontal_flip, tv_tensors.Image)
def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
    """Mirror ``image`` left-to-right by reversing its last (width) dimension; leading dims are untouched."""
    return image.flip([-1])


@_register_kernel_internal(horizontal_flip, PIL.Image.Image)
def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """Mirror a PIL image left-to-right using the PIL functional backend."""
    mirrored = _FP.hflip(image)
    return mirrored


@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    """Mirror a mask left-to-right; masks share the image kernel."""
    return horizontal_flip_image(image=mask)


def horizontal_flip_bounding_boxes(
    bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    """Mirror bounding boxes left-to-right inside a canvas.

    Args:
        bounding_boxes: box coordinates; the tensor is viewed as ``(-1, 4)``, so coordinates are assumed to be
            grouped in fours along the last dimension.
        format: coordinate layout. XYXY and XYWH are handled explicitly; any other value is treated as CXCYWH.
        canvas_size: ``(height, width)`` of the canvas; only the width is used.

    Returns:
        A new tensor with the same shape as ``bounding_boxes`` (the input is cloned, never modified).
    """
    original_shape = bounding_boxes.shape
    canvas_width = canvas_size[1]
    boxes = bounding_boxes.clone().reshape(-1, 4)

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        # x1' = W - x2 and x2' = W - x1: the left/right edges swap roles after mirroring.
        boxes[:, [2, 0]] = boxes[:, [0, 2]].sub_(canvas_width).neg_()
    elif format == tv_tensors.BoundingBoxFormat.XYWH:
        # New left edge: W - (x + w); the width is unchanged. Done in place on the column view.
        boxes[:, 0].add_(boxes[:, 2]).sub_(canvas_width).neg_()
    else:  # format == tv_tensors.BoundingBoxFormat.CXCYWH
        # Only the center moves: cx' = W - cx.
        boxes[:, 0].sub_(canvas_width).neg_()

    return boxes.reshape(original_shape)


@_register_kernel_internal(horizontal_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _horizontal_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
    """Unwrap ``BoundingBoxes``, flip the raw tensor using its metadata, and re-wrap like the input."""
    raw_boxes = inpt.as_subclass(torch.Tensor)
    flipped = horizontal_flip_bounding_boxes(raw_boxes, format=inpt.format, canvas_size=inpt.canvas_size)
    return tv_tensors.wrap(flipped, like=inpt)


@_register_kernel_internal(horizontal_flip, tv_tensors.Video)
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    """Mirror every frame of a video left-to-right; the image kernel only flips the last dim."""
    return horizontal_flip_image(image=video)


def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details."""
    if torch.jit.is_scripting():
        # Under TorchScript the type-based kernel lookup is skipped: the input is handled as a plain image tensor.
        return vertical_flip_image(inpt)

    _log_api_usage_once(vertical_flip)

    flip_kernel = _get_kernel(vertical_flip, type(inpt))
    return flip_kernel(inpt)


@_register_kernel_internal(vertical_flip, torch.Tensor)
@_register_kernel_internal(vertical_flip, tv_tensors.Image)
def vertical_flip_image(image: torch.Tensor) -> torch.Tensor:
    """Flip ``image`` top-to-bottom by reversing its second-to-last (height) dimension."""
    return image.flip([-2])


@_register_kernel_internal(vertical_flip, PIL.Image.Image)
def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """Flip a PIL image top-to-bottom using the PIL functional backend."""
    # Annotated with the image class ``PIL.Image.Image``; ``PIL.Image`` is the module, not a type.
    return _FP.vflip(image)


@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    """Flip a mask top-to-bottom; masks share the image kernel."""
    return vertical_flip_image(image=mask)


def vertical_flip_bounding_boxes(
    bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    """Flip bounding boxes top-to-bottom inside a canvas.

    Args:
        bounding_boxes: box coordinates; the tensor is viewed as ``(-1, 4)``, so coordinates are assumed to be
            grouped in fours along the last dimension.
        format: coordinate layout. XYXY and XYWH are handled explicitly; any other value is treated as CXCYWH.
        canvas_size: ``(height, width)`` of the canvas; only the height is used.

    Returns:
        A new tensor with the same shape as ``bounding_boxes`` (the input is cloned, never modified).
    """
    original_shape = bounding_boxes.shape
    canvas_height = canvas_size[0]
    boxes = bounding_boxes.clone().reshape(-1, 4)

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        # y1' = H - y2 and y2' = H - y1: the top/bottom edges swap roles after flipping.
        boxes[:, [1, 3]] = boxes[:, [3, 1]].sub_(canvas_height).neg_()
    elif format == tv_tensors.BoundingBoxFormat.XYWH:
        # New top edge: H - (y + h); the height is unchanged. Done in place on the column view.
        boxes[:, 1].add_(boxes[:, 3]).sub_(canvas_height).neg_()
    else:  # format == tv_tensors.BoundingBoxFormat.CXCYWH
        # Only the center moves: cy' = H - cy.
        boxes[:, 1].sub_(canvas_height).neg_()

    return boxes.reshape(original_shape)


@_register_kernel_internal(vertical_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _vertical_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
    """Unwrap ``BoundingBoxes``, flip the raw tensor using its metadata, and re-wrap like the input."""
    raw_boxes = inpt.as_subclass(torch.Tensor)
    flipped = vertical_flip_bounding_boxes(raw_boxes, format=inpt.format, canvas_size=inpt.canvas_size)
    return tv_tensors.wrap(flipped, like=inpt)


@_register_kernel_internal(vertical_flip, tv_tensors.Video)
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
    """Flip every frame of a video top-to-bottom; the image kernel only flips dim -2."""
    return vertical_flip_image(image=video)


# The primary names mirror the transform classes (e.g. `RandomHorizontalFlip`). The short `hflip` / `vflip`
# spellings are widespread and well understood, so they stay available as plain aliases and are not deprecated.
hflip, vflip = horizontal_flip, vertical_flip


def _compute_resized_output_size(
    canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
    """Return ``[new_height, new_width]`` for resizing a canvas of ``canvas_size`` to ``size``.

    A bare ``int`` for ``size`` is accepted and wrapped into a one-element list. ``max_size`` is only allowed
    when ``size`` has a single element; the actual computation is delegated to the v1 helper.

    Raises:
        ValueError: if ``max_size`` is given together with a ``size`` of length other than 1.
    """
    if isinstance(size, int):
        requested = [size]
    else:
        if max_size is not None and len(size) != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )
        requested = size
    return __compute_resized_output_size(canvas_size, size=requested, max_size=max_size)


def resize(
    inpt: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.Resize` for details."""
    if torch.jit.is_scripting():
        # Under TorchScript the type-based kernel lookup is skipped: the input is handled as a plain image tensor.
        return resize_image(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)

    _log_api_usage_once(resize)

    resize_kernel = _get_kernel(resize, type(inpt))
    return resize_kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


# Internal helper for resize_image. It lives at module level instead of inside resize_image because of TorchScript.
# Native uint8 support for bilinear and bicubic is limited to CPU, and according to our benchmarks on eager,
# non-AVX CPUs should still prefer the u8->f32->interpolate->u8 path for bilinear.
def _do_native_uint8_resize_on_cpu(interpolation: InterpolationMode) -> bool:
    """Tell whether a uint8 CPU tensor may be handed to ``interpolate`` directly for ``interpolation``.

    Returns:
        True for BICUBIC; for BILINEAR, True while ``torch._dynamo`` is compiling or when the CPU reports AVX2
        capability; False for every other mode.
    """
    if interpolation == InterpolationMode.BILINEAR:
        return torch._dynamo.is_compiling() or "AVX2" in torch.backends.cpu.get_cpu_capability()
    return interpolation == InterpolationMode.BICUBIC


@_register_kernel_internal(resize, torch.Tensor)
@_register_kernel_internal(resize, tv_tensors.Image)
def resize_image(
    image: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    """Resize an image tensor whose last three dims are ``(C, H, W)``; leading dims are flattened into a batch.

    Args:
        image: input tensor.
        size: target size, interpreted by ``_compute_resized_output_size`` (together with ``max_size``).
        interpolation: interpolation mode or Pillow integer constant.
        max_size: optional cap on the longer edge, see ``_compute_resized_output_size``.
        antialias: only honored for BILINEAR and BICUBIC; ``None`` is treated as ``False``.

    Returns:
        Tensor of shape ``(..., C, new_H, new_W)`` with the input dtype. When the size does not change, the
        input object itself is returned.
    """
    interpolation = _check_interpolation(interpolation)

    # `align_corners` and `antialias` only apply to bilinear/bicubic. The default of antialias is True from 0.17,
    # so for every other mode it is silently forced off (this is documented) instead of warning or erroring.
    align_corners: Optional[bool] = None
    use_antialias = False
    if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
        align_corners = False
        if antialias is not None:
            use_antialias = antialias

    input_shape = image.shape
    total_elements = image.numel()
    num_channels, old_height, old_width = input_shape[-3:]
    new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return image

    if total_elements > 0:
        original_dtype = image.dtype

        # Dtypes `interpolate` consumes directly; any other dtype round-trips through float32.
        native_dtypes = [torch.float32, torch.float64]
        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
            # uint8 is supported for nearest modes on both cpu and cuda.
            native_dtypes.append(torch.uint8)
        elif image.device.type == "cpu" and _do_native_uint8_resize_on_cpu(interpolation):
            native_dtypes.append(torch.uint8)

        image = image.reshape(-1, num_channels, old_height, old_width)

        # Work around https://github.com/pytorch/pytorch/issues/68430: `interpolate()` may allocate a contiguous
        # output even for an unambiguously channels_last input, which typically happens for single CHW images
        # faked into 1CHW. Restriding the batch dim makes torch core allocate a channels_last output, preserving
        # the memory format and, for uint8 bilinear, avoiding an extra re-packing copy of the output.
        # TODO: remove once https://github.com/pytorch/pytorch/issues/68430 is fixed
        # (possibly by https://github.com/pytorch/pytorch/pull/100373).
        batch_strides = image.stride()
        if (
            image.is_contiguous(memory_format=torch.channels_last)
            and image.shape[0] == 1
            and total_elements != batch_strides[0]
        ):
            patched_strides = list(batch_strides)
            patched_strides[0] = total_elements
            image = image.as_strided((1, num_channels, old_height, old_width), patched_strides)

        needs_float_roundtrip = original_dtype not in native_dtypes
        if needs_float_roundtrip:
            image = image.to(dtype=torch.float32)

        image = interpolate(
            image,
            size=[new_height, new_width],
            mode=interpolation.value,
            align_corners=align_corners,
            antialias=use_antialias,
        )

        if needs_float_roundtrip:
            if interpolation == InterpolationMode.BICUBIC and original_dtype == torch.uint8:
                # Keep values inside the uint8 range before casting back.
                # This path is hit on non-AVX archs, or on GPU.
                image = image.clamp_(min=0, max=255)
            if original_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
                image = image.round_()
            image = image.to(dtype=original_dtype)

    return image.reshape(input_shape[:-3] + (num_channels, new_height, new_width))


def _resize_image_pil(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
) -> PIL.Image.Image:
    """Resize a PIL image with Pillow's ``resize``.

    The interpolation argument is validated even when no resize is needed. When the target size equals the
    current size, the input object is returned unchanged.
    """
    old_height, old_width = image.height, image.width
    target_height, target_width = _compute_resized_output_size(
        (old_height, old_width),
        size=size,  # type: ignore[arg-type]
        max_size=max_size,
    )

    interpolation = _check_interpolation(interpolation)

    if (target_height, target_width) == (old_height, old_width):
        return image

    resample = pil_modes_mapping[interpolation]
    # Pillow expects (width, height).
    return image.resize((target_width, target_height), resample=resample)


@_register_kernel_internal(resize, PIL.Image.Image)
def __resize_image_pil_dispatch(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> PIL.Image.Image:
    """Route :func:`resize` for PIL images to ``_resize_image_pil``.

    ``antialias`` is not forwarded: anti-aliasing is always applied for PIL input, so an explicit ``False``
    only triggers a warning.
    """
    antialias_explicitly_disabled = antialias is False
    if antialias_explicitly_disabled:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)


def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
    """Resize a mask with nearest-neighbor interpolation.

    Masks with fewer than 3 dims (e.g. a bare H x W mask) temporarily gain a leading dim so that the image
    kernel can process them. That dim is removed again afterwards.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    resized = resize_image(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
    return resized.squeeze(0) if needs_squeeze else resized


@_register_kernel_internal(resize, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resize_mask_dispatch(
    inpt: tv_tensors.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> tv_tensors.Mask:
    """Resize a ``Mask`` with the mask kernel and re-wrap it like the input.

    Extra keyword arguments forwarded by :func:`resize` (e.g. ``interpolation``, ``antialias``) are accepted
    and ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    return tv_tensors.wrap(resize_mask(plain_mask, size, max_size=max_size), like=inpt)


def resize_bounding_boxes(
    bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Rescale bounding boxes for a canvas resized from ``canvas_size`` to ``size``.

    Returns:
        ``(boxes, new_canvas_size)``. The boxes keep the input dtype. When the canvas size does not change, the
        inputs are returned as-is.
    """
    old_height, old_width = canvas_size
    new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return bounding_boxes, canvas_size

    scale_x = new_width / old_width
    scale_y = new_height / old_height
    # The last dimension is scaled by [x, y, x, y] ratios, i.e. horizontal and vertical quantities alternate.
    ratios = torch.tensor([scale_x, scale_y, scale_x, scale_y], device=bounding_boxes.device)
    scaled = bounding_boxes.mul(ratios).to(bounding_boxes.dtype)
    return scaled, (new_height, new_width)


@_register_kernel_internal(resize, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resize_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> tv_tensors.BoundingBoxes:
    """Resize ``BoundingBoxes`` and re-wrap them with the updated canvas size.

    Extra keyword arguments forwarded by :func:`resize` are accepted and ignored.
    """
    raw_boxes = inpt.as_subclass(torch.Tensor)
    resized_boxes, new_canvas_size = resize_bounding_boxes(raw_boxes, inpt.canvas_size, size, max_size=max_size)
    return tv_tensors.wrap(resized_boxes, like=inpt, canvas_size=new_canvas_size)


@_register_kernel_internal(resize, tv_tensors.Video)
def resize_video(
    video: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    """Resize every frame of a video; the image kernel flattens all leading dims into one batch."""
    return resize_image(
        video,
        size=size,
        interpolation=interpolation,
        max_size=max_size,
        antialias=antialias,
    )


def affine(
    inpt: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomAffine` for details."""
    if torch.jit.is_scripting():
        # Under TorchScript the type-based kernel lookup is skipped: the input is handled as a plain image tensor.
        return affine_image(
            inpt,
            angle=angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )

    _log_api_usage_once(affine)

    affine_kernel = _get_kernel(affine, type(inpt))
    return affine_kernel(
        inpt,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )


def _affine_parse_args(
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
    """Validate and normalize the parameters shared by the affine kernels.

    Returns:
        ``(angle, translate, shear, center)``. ``angle`` is a float, ``translate`` is a list, ``shear`` is a
        two-element list, and ``center`` is either None or a list of floats. ``scale`` and ``interpolation`` are
        only validated.

    Raises:
        TypeError: if an argument has the wrong type.
        ValueError: if ``translate`` or ``shear`` has the wrong length, or ``scale`` is not positive.
    """
    # Validation. The order of these checks decides which error is reported first.
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")
    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    # Normalization.
    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    # A scalar shear becomes [shear, 0.0] (second component zero); tuples become lists.
    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]
    elif isinstance(shear, tuple):
        shear = list(shear)

    # A single-element sequence applies the same shear to both components.
    if len(shear) == 1:
        shear = [shear[0], shear[0]]
    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None:
        if not isinstance(center, (list, tuple)):
            raise TypeError("Argument center should be a sequence")
        center = [float(c) for c in center]

    return angle, translate, shear, center


def _get_inverse_affine_matrix(
    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
    """Compute a 2x3 affine matrix, flattened row-major into 6 floats, for rotate/scale/shear/translate.

    Args:
        center: ``[cx, cy]`` pivot around which rotation, scale and shear are applied.
        angle: rotation angle in degrees.
        translate: ``[tx, ty]`` translation.
        scale: isotropic scale factor. The inverse divides by it, so it must be non-zero; the visible callers
            validate ``scale > 0`` via ``_affine_parse_args``.
        shear: ``[sx, sy]`` shear angles in degrees.
        inverted: if True (default), return the inverse mapping ``M^-1`` (what Pillow expects, and what the
            tensor kernel samples with). If False, return the forward matrix ``M``.

    Returns:
        ``[m00, m01, m02, m10, m11, m12]``.
    """
    # Helper computing the affine matrix of the transformation (inverse by default, forward if not `inverted`).

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                      , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # Cached results (reused by several matrix entries below)
    cos_sy = math.cos(sy)
    tan_sx = math.tan(sx)
    rot_minus_sy = rot - sy
    cx_plus_tx = cx + tx
    cy_plus_ty = cy + ty

    # Rotate Scale Shear (RSS) without scaling: [[a, b], [c, d]]
    a = math.cos(rot_minus_sy) / cos_sy
    b = -(a * tan_sx + math.sin(rot))
    c = math.sin(rot_minus_sy) / cos_sy
    d = math.cos(rot) - c * tan_sx

    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1,
        # so the 2x2 inverse is simply the adjugate, divided by the scale.
        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
    else:
        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
        # Apply inverse of center translation: RSS * C^-1
        # and then apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy

    return matrix


def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    """Return the ``(width, height)`` of the canvas that contains a ``w x h`` image after applying ``matrix``.

    The four image corners are transformed, and the axis-aligned extent of the result is measured (the
    "expand" computation of Pillow's rotate).

    Args:
        matrix: 2x3 affine matrix flattened row-major into 6 floats, expressed in a frame whose origin is the
            image center.
        w: input width.
        h: input height.

    Returns:
        ``(width, height)`` of the output canvas.
    """
    # Inspired by the PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # The rows of `pts` are the Top-Left, Bottom-Left, Bottom-Right and Top-Right corners, in homogeneous
    # coordinates. They are shifted to the torch affine convention, where (0, 0) is the image center
    # (w * 0.5, h * 0.5). The row order is irrelevant, because only the per-axis min/max is used below.
    half_w = 0.5 * w
    half_h = 0.5 * h
    pts = torch.tensor(
        [
            [-half_w, -half_h, 1.0],
            [-half_w, half_h, 1.0],
            [half_w, half_h, 1.0],
            [half_w, -half_h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, max_vals = new_pts.aminmax(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    halfs = torch.tensor((half_w, half_h))
    min_vals.add_(halfs)
    max_vals.add_(halfs)

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    inv_tol = 1.0 / tol
    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
    size = cmax.sub_(cmin)
    return int(size[0]), int(size[1])  # w, h


def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
    """Resample ``img`` at ``grid`` with ``grid_sample`` and optionally paint uncovered pixels with ``fill``.

    Args:
        img: tensor whose trailing dims are ``(C, H, W)``; leading dims are flattened into one batch.
        grid: sampling grid; ``grid.shape[1]`` / ``grid.shape[2]`` give the output height / width. It is
            assumed to be floating point and is expanded across the batch (so presumably it has batch size 1).
        mode: ``grid_sample`` mode. Fill handling treats "nearest" specially and every other mode as bilinear.
        fill: ``None`` for zero padding, or a scalar / per-channel sequence for out-of-bounds pixels.

    Returns:
        Tensor of shape ``(..., C, out_H, out_W)`` with the dtype of ``img``. Images whose dtype differs from
        the grid's are sampled in the grid dtype and rounded back.
    """
    in_shape = img.shape
    out_height, out_width = grid.shape[1], grid.shape[2]
    num_channels, in_height, in_width = in_shape[-3:]
    out_shape = in_shape[:-3] + (num_channels, out_height, out_width)

    if img.numel() == 0:
        return img.reshape(out_shape)

    img = img.reshape(-1, num_channels, in_height, in_width)
    batch_size = img.shape[0]

    # We are using context knowledge that grid should have float dtype.
    same_dtype = img.dtype == grid.dtype
    work = img if same_dtype else img.to(grid.dtype)

    if batch_size > 1:
        # Share one sampling grid across every image of the flattened batch.
        grid = grid.expand(batch_size, -1, -1, -1)

    if fill is not None:
        # Append an all-ones channel. After sampling it measures how much of each output pixel came from inside
        # the input, which is cheaper than calling grid_sample() a second time.
        ones = torch.ones((batch_size, 1, in_height, in_width), dtype=work.dtype, device=work.device)
        work = torch.cat((work, ones), dim=1)

    work = grid_sample(work, grid, mode=mode, padding_mode="zeros", align_corners=False)

    if fill is not None:
        work, coverage = torch.tensor_split(work, indices=(-1,), dim=-3)
        coverage = coverage.expand_as(work)
        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
        fill_img = torch.tensor(fill_list, dtype=work.dtype, device=work.device).view(1, -1, 1, 1)
        if mode == "nearest":
            outside = coverage < 0.5
            work[outside] = fill_img.expand_as(work)[outside]
        else:  # 'bilinear'
            # img * mask + (1.0 - mask) * fill == mask * (img - fill) + fill
            work = work.sub_(fill_img).mul_(coverage).add_(fill_img)

    if not same_dtype:
        work = work.round_().to(img.dtype)

    return work.reshape(out_shape)


def _assert_grid_transform_inputs(
    image: torch.Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: _FillTypeJIT,
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    """Validate the arguments of a grid-based (affine / perspective) tensor transform.

    Raises:
        TypeError: if ``matrix`` is given but is not a list.
        ValueError: if ``matrix`` does not have 6 values, ``coeffs`` does not have 8, a sequence ``fill`` with
            more than one value does not match the channel count (``image.shape[-3]``), a non-sequence ``fill``
            is not an int/float, or ``interpolation`` is not in ``supported_interpolation_modes``.
    """
    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        if len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None:
        if isinstance(fill, (tuple, list)):
            # A single value broadcasts to every channel; otherwise one value per channel is required.
            length = len(fill)
            num_channels = image.shape[-3]
            if length > 1 and length != num_channels:
                raise ValueError(
                    "The number of elements in 'fill' cannot broadcast to match the number of "
                    f"channels of the image ({length} != {num_channels})"
                )
        elif not isinstance(fill, (int, float)):
            raise ValueError("Argument fill should be either int, float, tuple or list")

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _affine_grid(
    theta: torch.Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> torch.Tensor:
    """Build a ``grid_sample`` sampling grid for the affine matrix ``theta``.

    Adapted from
    https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    AffineGridGenerator.cpp#L18
    Differences from AffineGridGenerator:
    1) grid values are normalized after applying theta;
    2) normalization may use a different image size than the output, which covers the "expand" option of
       PIL.Image.rotate.

    Args:
        theta: ``(1, 2, 3)`` affine matrix in pixel units, with the origin at the image center. It is not
            modified.
        w, h: input width / height, used to normalize the transformed coordinates.
        ow, oh: output width / height, i.e. the number of sampling points along each axis.

    Returns:
        Tensor of shape ``(1, oh, ow, 2)`` with normalized ``(x, y)`` sampling locations.
    """
    dtype = theta.dtype
    device = theta.device

    # Homogeneous pixel-center coordinates (x, y, 1) of the output, centered on the output image.
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    # Out-of-place division on purpose: `theta.transpose(1, 2)` is a view of `theta`, so an in-place `div_`
    # would silently rescale the caller's tensor.
    rescaled_theta = theta.transpose(1, 2).div(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


@_register_kernel_internal(affine, torch.Tensor)
@_register_kernel_internal(affine, tv_tensors.Image)
def affine_image(
    image: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transform to an image tensor; the output keeps the input height and width.

    Only "nearest" and "bilinear" interpolation are accepted. Floating-point images are sampled in their own
    dtype; all other dtypes are sampled in float32.
    """
    interpolation = _check_interpolation(interpolation)
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    height, width = image.shape[-2:]

    # The grid helpers use the image center as origin, so a user-given pixel center is shifted accordingly.
    center_f = [0.0, 0.0]
    if center is not None:
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [float(t) for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    grid_dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(matrix, dtype=grid_dtype, device=image.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


@_register_kernel_internal(affine, PIL.Image.Image)
def _affine_image_pil(
    image: PIL.Image.Image,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    """Apply an affine transform to a PIL image through the PIL functional backend."""
    interpolation = _check_interpolation(interpolation)
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    if center is None:
        # Default pivot is (w * 0.5, h * 0.5), deliberately without the +0.5 offset. With the offset, an image
        # rotated by 90 degrees is shifted relative to torch.rot90 / the tensor affine output.
        height, width = _get_size_image_pil(image)
        center = [width * 0.5, height * 0.5]

    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
    resample = pil_modes_mapping[interpolation]
    return _FP.affine(image, matrix, interpolation=resample, fill=fill)


def _affine_bounding_boxes_with_expand(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    expand: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Apply an affine transform to bounding boxes, optionally expanding the canvas.

    The four corners of each box are mapped with the forward affine matrix. The output box is
    the axis-aligned min/max hull of the mapped corners, clamped to the (possibly new) canvas.

    Returns:
        The transformed boxes, with the input's shape, dtype and ``format``, plus the canvas
        size as ``(height, width)``. The canvas size only changes when ``expand`` is True.
    """
    if bounding_boxes.numel() == 0:
        return bounding_boxes, canvas_size

    input_shape = bounding_boxes.shape
    input_dtype = bounding_boxes.dtype
    # Work on a float copy, so the in-place format conversion below never writes into the caller's tensor.
    if bounding_boxes.is_floating_point():
        work = bounding_boxes.clone()
    else:
        work = bounding_boxes.float()
    dtype = work.dtype
    device = work.device
    xyxy = (
        convert_bounding_box_format(
            work, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
        )
    ).reshape(-1, 4)

    angle, translate, shear, center = _affine_parse_args(
        angle, translate, scale, shear, InterpolationMode.NEAREST, center
    )
    if center is None:
        canvas_h, canvas_w = canvas_size
        center = [canvas_w * 0.5, canvas_h * 0.5]

    # Forward (non-inverted) 2x3 matrix, transposed to (3, 2) so row vectors can be right-multiplied.
    forward = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
    matrix_t = torch.tensor(forward, dtype=dtype, device=device).reshape(2, 3).T

    # Homogeneous corner points, shape (N * 4, 3). For each box they are ordered
    # (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax), each followed by a 1.
    corners = xyxy[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    ones = torch.ones(corners.shape[0], 1, device=device, dtype=dtype)
    corners = torch.cat([corners, ones], dim=-1)

    # Mapped corners grouped per box as (N, 4, 2); each new box is their min/max hull.
    mapped = torch.matmul(corners, matrix_t).reshape(-1, 4, 2)
    lo, hi = torch.aminmax(mapped, dim=1)
    result = torch.cat([lo, hi], dim=1)

    if expand:
        canvas_h, canvas_w = canvas_size
        # Homogeneous canvas corners, ordered top-left, bottom-left, bottom-right, top-right.
        frame = torch.tensor(
            [
                [0.0, 0.0, 1.0],
                [0.0, float(canvas_h), 1.0],
                [float(canvas_w), float(canvas_h), 1.0],
                [float(canvas_w), 0.0, 1.0],
            ],
            dtype=dtype,
            device=device,
        )
        # Shift the boxes so the minimum corner of the transformed canvas becomes the new origin.
        frame_min = torch.amin(torch.matmul(frame, matrix_t), dim=0, keepdim=True)
        result.sub_(frame_min.repeat((1, 2)))
        # The new canvas size comes from the inverse matrix, the same input rotate_image gives
        # _compute_affine_output_size when expanding.
        inverse = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        new_w, new_h = _compute_affine_output_size(inverse, canvas_w, canvas_h)
        canvas_size = (new_h, new_w)

    result = clamp_bounding_boxes(result, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
    result = convert_bounding_box_format(
        result, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(input_shape)
    return result.to(input_dtype), canvas_size
816
817


818
819
def affine_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Affine-transform bounding boxes without expanding the canvas.

    Only the boxes are returned. Since ``expand=False``, the canvas size is unchanged.
    """
    transformed = _affine_bounding_boxes_with_expand(
        bounding_boxes,
        format=format,
        canvas_size=canvas_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
        expand=False,
    )
    return transformed[0]
840
841


842
@_register_kernel_internal(affine, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _affine_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.BoundingBoxes:
    """``affine`` kernel for ``tv_tensors.BoundingBoxes``.

    The boxes are unwrapped to a plain tensor, transformed, and rewrapped like ``inpt``.
    Extra keyword arguments (e.g. ``interpolation``/``fill``) are accepted and ignored.
    """
    plain = inpt.as_subclass(torch.Tensor)
    transformed = affine_bounding_boxes(
        plain,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
    )
    return tv_tensors.wrap(transformed, like=inpt)
863
864


865
866
def affine_mask(
    mask: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Affine-transform a mask using nearest-neighbour interpolation.

    A mask with fewer than 3 dims gets a temporary leading dim so it can go through
    ``affine_image``. That dim is removed again before returning.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    # NEAREST is hard-coded, presumably so discrete mask values are never blended.
    output = affine_image(
        mask,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )
    return output.squeeze(0) if needs_squeeze else output

896

897
@_register_kernel_internal(affine, tv_tensors.Mask, tv_tensor_wrapper=False)
def _affine_mask_dispatch(
    inpt: tv_tensors.Mask,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.Mask:
    """``affine`` kernel for ``tv_tensors.Mask``.

    The mask is unwrapped, transformed, and rewrapped like ``inpt``. Extra keyword arguments
    (such as ``interpolation``) are ignored, because ``affine_mask`` always uses NEAREST.
    """
    plain = inpt.as_subclass(torch.Tensor)
    transformed = affine_mask(
        plain,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        fill=fill,
        center=center,
    )
    return tv_tensors.wrap(transformed, like=inpt)
918
919


920
@_register_kernel_internal(affine, tv_tensors.Video)
def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """``affine`` kernel for videos: the frames are transformed by ``affine_image``."""
    return affine_image(
        video,
        angle=angle,
        scale=scale,
        translate=translate,
        shear=shear,
        center=center,
        interpolation=interpolation,
        fill=fill,
    )


943
def rotate(
    inpt: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomRotation` for details."""
    # Under TorchScript, go straight to the image kernel instead of the type-based dispatch below.
    if torch.jit.is_scripting():
        return rotate_image(inpt, angle=angle, center=center, expand=expand, interpolation=interpolation, fill=fill)

    _log_api_usage_once(rotate)

    # Eager mode: look up the kernel registered for the concrete input type.
    kernel = _get_kernel(rotate, type(inpt))
    return kernel(inpt, angle=angle, center=center, expand=expand, interpolation=interpolation, fill=fill)


@_register_kernel_internal(rotate, torch.Tensor)
@_register_kernel_internal(rotate, tv_tensors.Image)
def rotate_image(
    image: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate an image tensor by ``angle`` using a sampling grid.

    If ``expand`` is True, the output size is enlarged by ``_compute_affine_output_size``.
    Otherwise it matches the input.
    """
    interpolation = _check_interpolation(interpolation)
    in_h, in_w = image.shape[-2:]

    # `center` is in pixel coordinates; the matrix expects it relative to the image center.
    center_f = [0.0, 0.0]
    if center is not None:
        center_f = [(c - s * 0.5) for c, s in zip(center, [in_w, in_h])]

    # rotate and affine currently disagree on the rotation direction, so the angle is negated.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    if expand:
        out_w, out_h = _compute_affine_output_size(matrix, in_w, in_h)
    else:
        out_w, out_h = in_w, in_h

    grid_dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(matrix, dtype=grid_dtype, device=image.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=in_w, h=in_h, ow=out_w, oh=out_h)
    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)
993
994


995
@_register_kernel_internal(rotate, PIL.Image.Image)
def _rotate_image_pil(
    image: PIL.Image.Image,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    """``rotate`` kernel for PIL images, backed by the PIL functional implementation."""
    resample = pil_modes_mapping[_check_interpolation(interpolation)]
    return _FP.rotate(image, angle, interpolation=resample, expand=expand, fill=fill, center=center)


1011
1012
def rotate_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Rotate bounding boxes; returns the boxes and the (possibly expanded) canvas size.

    A rotation is expressed as an affine transform with no translation, unit scale and no shear.
    The angle is negated because rotate and affine use opposite rotation directions.
    """
    no_shift = [0.0, 0.0]
    no_shear = [0.0, 0.0]
    return _affine_bounding_boxes_with_expand(
        bounding_boxes,
        format=format,
        canvas_size=canvas_size,
        angle=-angle,
        translate=no_shift,
        scale=1.0,
        shear=no_shear,
        center=center,
        expand=expand,
    )
1030
1031


1032
@_register_kernel_internal(rotate, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _rotate_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
) -> tv_tensors.BoundingBoxes:
    """``rotate`` kernel for ``tv_tensors.BoundingBoxes``.

    The output is rewrapped with the canvas size returned by the kernel, since ``expand`` can change it.
    """
    plain = inpt.as_subclass(torch.Tensor)
    rotated, new_canvas_size = rotate_bounding_boxes(
        plain,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        angle=angle,
        expand=expand,
        center=center,
    )
    return tv_tensors.wrap(rotated, like=inpt, canvas_size=new_canvas_size)
1045
1046


1047
1048
def rotate_mask(
    mask: torch.Tensor,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate a mask using nearest-neighbour interpolation.

    A mask with fewer than 3 dims gets a temporary leading dim for ``rotate_image``, which is
    removed again before returning.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = rotate_image(
        mask,
        angle=angle,
        interpolation=InterpolationMode.NEAREST,
        expand=expand,
        center=center,
        fill=fill,
    )
    return output.squeeze(0) if needs_squeeze else output

1074

1075
@_register_kernel_internal(rotate, tv_tensors.Mask, tv_tensor_wrapper=False)
def _rotate_mask_dispatch(
    inpt: tv_tensors.Mask,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
    **kwargs,
) -> tv_tensors.Mask:
    """``rotate`` kernel for ``tv_tensors.Mask``; extra keyword arguments are ignored."""
    plain = inpt.as_subclass(torch.Tensor)
    rotated = rotate_mask(plain, angle=angle, expand=expand, center=center, fill=fill)
    return tv_tensors.wrap(rotated, like=inpt)
1086
1087


1088
@_register_kernel_internal(rotate, tv_tensors.Video)
def rotate_video(
    video: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """``rotate`` kernel for videos: the frames are rotated by ``rotate_image``."""
    return rotate_image(
        video,
        angle=angle,
        interpolation=interpolation,
        expand=expand,
        center=center,
        fill=fill,
    )
1098
1099


1100
def pad(
    inpt: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.Pad` for details."""
    # Under TorchScript, go straight to the image kernel instead of the type-based dispatch below.
    if torch.jit.is_scripting():
        return pad_image(inpt, padding=padding, padding_mode=padding_mode, fill=fill)

    _log_api_usage_once(pad)

    # Eager mode: look up the kernel registered for the concrete input type.
    kernel = _get_kernel(pad, type(inpt))
    return kernel(inpt, padding=padding, padding_mode=padding_mode, fill=fill)
1114
1115


1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, (tuple, list)):
        if len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        elif len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        else:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )
    else:
        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")

    return [pad_left, pad_right, pad_top, pad_bottom]
1138

1139

1140
@_register_kernel_internal(pad, torch.Tensor)
@_register_kernel_internal(pad, tv_tensors.Image)
def pad_image(
    image: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad an image tensor along its last two dims.

    ``padding`` follows the PIL convention (see ``_parse_pad_padding``). ``fill`` defaults to 0.
    A scalar or single-element ``fill`` is applied uniformly. A longer ``fill`` is applied per
    channel and is only supported with ``"constant"`` mode.
    """
    # NOTE: the API (like PIL) orders 4-element padding as [left, top, right, bottom], whereas
    # `torch_pad` wants [left, right, top, bottom]. `_parse_pad_padding` produces the latter.
    torch_padding = _parse_pad_padding(padding)

    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
        raise ValueError(
            f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
            f"but got `'{padding_mode}'`."
        )

    if fill is None:
        fill = 0

    # Scalar and single-element fills share the scalar path; longer fills are per-channel.
    if isinstance(fill, (int, float)):
        return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
    elif len(fill) == 1:
        return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
    else:
        return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
1168
1169
1170


def _pad_with_scalar_fill(
1171
    image: torch.Tensor,
1172
1173
1174
    torch_padding: List[int],
    fill: Union[int, float],
    padding_mode: str,
1175
) -> torch.Tensor:
1176
1177
    shape = image.shape
    num_channels, height, width = shape[-3:]
1178

1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
    batch_size = 1
    for s in shape[:-3]:
        batch_size *= s

    image = image.reshape(batch_size, num_channels, height, width)

    if padding_mode == "edge":
        # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
        # name.
        padding_mode = "replicate"

    if padding_mode == "constant":
        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
    elif padding_mode in ("reflect", "replicate"):
        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
        # TODO: See https://github.com/pytorch/pytorch/issues/40763
        dtype = image.dtype
        if not image.is_floating_point():
            needs_cast = True
            image = image.to(torch.float32)
        else:
            needs_cast = False
1202

1203
1204
1205
1206
1207
        image = torch_pad(image, torch_padding, mode=padding_mode)

        if needs_cast:
            image = image.to(dtype)
    else:  # padding_mode == "symmetric"
1208
        image = _pad_symmetric(image, torch_padding)
1209
1210

    new_height, new_width = image.shape[-2:]
1211

1212
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
1213
1214


1215
# TODO: This should be removed once torch_pad supports non-scalar padding values
1216
def _pad_with_vector_fill(
    image: torch.Tensor,
    torch_padding: List[int],
    fill: List[float],
    padding_mode: str,
) -> torch.Tensor:
    """Constant-pad ``image`` with a per-channel ``fill``.

    ``torch_padding`` is ``[left, right, top, bottom]``. Only ``"constant"`` mode is supported;
    any other mode raises ``ValueError``.
    """
    if padding_mode != "constant":
        raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

    # Zero-pad first, then overwrite each padded border strip with the per-channel values.
    output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    left, right, top, bottom = torch_padding

    # Build the tensor in its autodetected dtype and convert afterwards. This avoids an implicit
    # float -> int conversion, e.g. for a uint8 image with floating point fill values.
    fill_values = torch.tensor(fill, device=image.device).to(dtype=image.dtype).reshape(-1, 1, 1)

    if top > 0:
        output[..., :top, :] = fill_values
    if left > 0:
        output[..., :, :left] = fill_values
    if bottom > 0:
        output[..., -bottom:, :] = fill_values
    if right > 0:
        output[..., :, -right:] = fill_values
    return output


1244
# Register `_FP.pad` as the `pad` kernel for `PIL.Image.Image` inputs, and keep the result of the
# registration decorator under the module-private name `_pad_image_pil`.
_pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)
1245
1246


1247
@_register_kernel_internal(pad, tv_tensors.Mask)
def pad_mask(
    mask: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad a mask. Only a scalar ``fill`` (default 0) is accepted.

    A mask with fewer than 3 dims gets a temporary leading dim for ``pad_image``, which is
    removed again before returning.
    """
    if fill is None:
        fill = 0

    if isinstance(fill, (tuple, list)):
        raise ValueError("Non-scalar fill value is not supported")

    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = pad_image(mask, padding=padding, fill=fill, padding_mode=padding_mode)
    return output.squeeze(0) if needs_squeeze else output
1272
1273


1274
1275
def pad_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    padding: List[int],
    padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Shift bounding boxes to match a padded canvas.

    Returns the clamped boxes and the grown canvas size as ``(height, width)``. Only
    ``"constant"`` mode is supported.
    """
    if padding_mode != "constant":
        # TODO: add support of other padding modes
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

    left, right, top, bottom = _parse_pad_padding(padding)

    # For XYXY both corners move. For every other format only the first two coordinates are
    # offset; presumably the last two are sizes, which padding does not change.
    if format == tv_tensors.BoundingBoxFormat.XYXY:
        offset = [left, top, left, top]
    else:
        offset = [left, top, 0, 0]
    shifted = bounding_boxes + torch.tensor(offset, dtype=bounding_boxes.dtype, device=bounding_boxes.device)

    old_height, old_width = canvas_size
    canvas_size = (old_height + top + bottom, old_width + left + right)

    return clamp_bounding_boxes(shifted, format=format, canvas_size=canvas_size), canvas_size
1299
1300


1301
@_register_kernel_internal(pad, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _pad_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
) -> tv_tensors.BoundingBoxes:
    """``pad`` kernel for ``tv_tensors.BoundingBoxes``.

    The output is rewrapped with the enlarged canvas size. Extra keyword arguments (e.g.
    ``fill``) are ignored.
    """
    plain = inpt.as_subclass(torch.Tensor)
    padded, new_canvas_size = pad_bounding_boxes(
        plain,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        padding=padding,
        padding_mode=padding_mode,
    )
    return tv_tensors.wrap(padded, like=inpt, canvas_size=new_canvas_size)
1313
1314


1315
@_register_kernel_internal(pad, tv_tensors.Video)
def pad_video(
    video: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """``pad`` kernel for videos: the frames are padded by ``pad_image``."""
    return pad_image(video, padding=padding, fill=fill, padding_mode=padding_mode)
1323
1324


1325
def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomCrop` for details."""
    # Under TorchScript, go straight to the image kernel instead of the type-based dispatch below.
    if torch.jit.is_scripting():
        return crop_image(inpt, left=left, top=top, width=width, height=height)

    _log_api_usage_once(crop)

    # Eager mode: look up the kernel registered for the concrete input type.
    kernel = _get_kernel(crop, type(inpt))
    return kernel(inpt, left=left, top=top, width=width, height=height)
1334

1335
1336

@_register_kernel_internal(crop, torch.Tensor)
@_register_kernel_internal(crop, tv_tensors.Image)
def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop the ``height x width`` window whose top-left corner is at (``top``, ``left``).

    The window may extend past any image border, including negative ``top``/``left``. Any part
    of it outside the image is zero-filled, so for non-negative ``height``/``width`` the output
    spatial size is always ``(height, width)``.
    """
    h, w = image.shape[-2:]

    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        # Clamp both slice bounds at 0. A negative end (the window lies entirely above or left of
        # the image) would otherwise be read as "counting from the end" and keep most of the image
        # instead of nothing.
        image = image[..., max(top, 0) : max(bottom, 0), max(left, 0) : max(right, 0)]
        # Zero-pad the out-of-image part of the window, as [left, right, top, bottom].
        torch_padding = [
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    return image[..., top:bottom, left:right]


1356
1357
# PIL kernel for `crop`: alias `_FP.crop` under a module-private name and register it for
# `PIL.Image.Image` inputs.
_crop_image_pil = _FP.crop
_register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil)
1358
1359


1360
1361
def crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Move bounding boxes into the coordinate frame of a crop.

    Returns the clamped boxes and the new canvas size ``(height, width)``.
    """
    # The crop origin becomes (left, top). Negative left/top place the origin outside the image,
    # which acts as an implicit pad. For non-XYXY formats only the first two coordinates shift.
    if format == tv_tensors.BoundingBoxFormat.XYXY:
        shift = [left, top, left, top]
    else:
        shift = [left, top, 0, 0]
    shifted = bounding_boxes - torch.tensor(shift, dtype=bounding_boxes.dtype, device=bounding_boxes.device)

    new_canvas_size = (height, width)
    return clamp_bounding_boxes(shifted, format=format, canvas_size=new_canvas_size), new_canvas_size
1379
1380


1381
@_register_kernel_internal(crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int
) -> tv_tensors.BoundingBoxes:
    """``crop`` kernel for ``tv_tensors.BoundingBoxes``; rewraps with the crop's canvas size."""
    plain = inpt.as_subclass(torch.Tensor)
    cropped, new_canvas_size = crop_bounding_boxes(
        plain, format=inpt.format, left=left, top=top, width=width, height=height
    )
    return tv_tensors.wrap(cropped, like=inpt, canvas_size=new_canvas_size)
1389
1390


1391
@_register_kernel_internal(crop, tv_tensors.Mask)
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a mask with ``crop_image``.

    A mask with fewer than 3 dims gets a temporary leading dim, which is removed again before
    returning.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = crop_image(mask, top, left, height, width)
    return output.squeeze(0) if needs_squeeze else output
1405
1406


1407
@_register_kernel_internal(crop, tv_tensors.Video)
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """``crop`` kernel for videos: delegates to ``crop_image``, which only indexes the last two dims."""
    return crop_image(video, top=top, left=left, height=height, width=width)
1410
1411


1412
def perspective(
    inpt: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomPerspective` for details."""
    # Under TorchScript, go straight to the image kernel instead of the type-based dispatch below.
    if torch.jit.is_scripting():
        return perspective_image(
            inpt,
            startpoints=startpoints,
            endpoints=endpoints,
            coefficients=coefficients,
            interpolation=interpolation,
            fill=fill,
        )

    _log_api_usage_once(perspective)

    # Eager mode: look up the kernel registered for the concrete input type.
    kernel = _get_kernel(perspective, type(inpt))
    return kernel(
        inpt,
        startpoints=startpoints,
        endpoints=endpoints,
        coefficients=coefficients,
        interpolation=interpolation,
        fill=fill,
    )

1443

1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394

    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
1459
    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
1460
    base_grid[..., 0].copy_(x_grid)
1461
    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
1462
1463
1464
1465
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
1466
1467
1468
    shape = (1, oh * ow, 3)
    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))
1469
1470
1471
1472
1473

    output_grid = output_grid1.div_(output_grid2).sub_(1.0)
    return output_grid.view(1, oh, ow, 2)


1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
def _perspective_coefficients(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]],
) -> List[float]:
    if coefficients is not None:
        if startpoints is not None and endpoints is not None:
            raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.")
        elif len(coefficients) != 8:
            raise ValueError("Argument coefficients should have 8 float values")
        return coefficients
    elif startpoints is not None and endpoints is not None:
        return _get_perspective_coeffs(startpoints, endpoints)
    else:
        raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")


1491
@_register_kernel_internal(perspective, torch.Tensor)
@_register_kernel_internal(perspective, tv_tensors.Image)
def perspective_image(
    image: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transform to an image tensor.

    The output has the same spatial size as the input. The transform is given either as
    point correspondences or as 8 raw coefficients (see ``_perspective_coefficients``).
    """
    coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)

    _assert_grid_transform_inputs(
        image,
        matrix=None,
        interpolation=interpolation.value,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=coeffs,
    )

    out_h, out_w = image.shape[-2:]
    grid_dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    grid = _perspective_grid(coeffs, ow=out_w, oh=out_h, dtype=grid_dtype, device=image.device)
    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)
1517
1518


1519
@_register_kernel_internal(perspective, PIL.Image.Image)
def _perspective_image_pil(
    image: PIL.Image.Image,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
    """``perspective`` kernel for PIL images, backed by the PIL functional implementation."""
    coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    resample = pil_modes_mapping[_check_interpolation(interpolation)]
    return _FP.perspective(image, coeffs, interpolation=resample, fill=fill)
1531
1532


def perspective_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transform to bounding boxes.

    Boxes are converted to XYXY. The four corners of each box are mapped through the inverse of the
    perspective transform used for images. The output box is the axis-aligned hull of the mapped corners,
    clamped to ``canvas_size``, converted back to ``format`` and reshaped to the input's shape.
    Empty inputs are returned unchanged.

    Raises:
        RuntimeError: if the coefficients cannot be inverted (``c[0] * c[4] - c[1] * c[3] == 0``).
    """
    if bounding_boxes.numel() == 0:
        return bounding_boxes

    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
    bounding_boxes = (
        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32
    device = bounding_boxes.device

    # perspective_coeffs are computed as endpoint -> start point
    # We have to invert perspective_coeffs for bboxes:
    # (x, y) - end point and (x_out, y_out) - start point
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # and we would like to get:
    # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # and compute inv_coeffs in terms of coeffs (adjugate of the 3x3 matrix, normalized so its last entry is 1)

    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
    if denom == 0:
        raise RuntimeError(
            f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
            f"Denominator is zero, denom={denom}"
        )

    inv_coeffs = [
        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
        (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
        (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
        (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
    ]

    theta1 = torch.tensor(
        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
        dtype=dtype,
        device=device,
    )

    # Both rows are the shared denominator, so numer / denom below divides x and y by the same value.
    theta2 = torch.tensor(
        [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
    )

    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-right, bottom-left corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2).to(dtype)
    # Build the homogeneous column in the same dtype as theta1/theta2. A default (float32) column would make
    # torch.cat promote e.g. float16 points to float32, and the matmul with float16 thetas would then fail.
    points = torch.cat([points, torch.ones(points.shape[0], 1, dtype=dtype, device=device)], dim=-1)
    # 2) Now let's transform the points using the inverse perspective matrices
    #   x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2]) / (inv_coeffs[6] * x_out + ... + 1)
    #   y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5]) / (inv_coeffs[6] * x_out + ... + 1)
    numer_points = torch.matmul(points, theta1.T)
    denom_points = torch.matmul(points, theta2.T)
    transformed_points = numer_points.div_(denom_points)

    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)

    # out_bboxes has shape [N boxes, 4] in XYXY, cast back to the boxes' dtype.
    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    return convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)


@_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _perspective_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.BoundingBoxes:
    """Unwrap ``BoundingBoxes``, run :func:`perspective_bounding_boxes` and re-wrap the result like ``inpt``.

    Any extra keyword arguments are accepted and ignored.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    transformed = perspective_bounding_boxes(
        plain_boxes,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        startpoints=startpoints,
        endpoints=endpoints,
        coefficients=coefficients,
    )
    return tv_tensors.wrap(transformed, like=inpt)


def perspective_mask(
    mask: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Warp a mask with a perspective transform using nearest-neighbour interpolation.

    Masks with fewer than three dimensions get a temporary leading dimension for the image kernel, which is
    removed again before returning.
    """
    add_channel = mask.ndim < 3
    if add_channel:
        mask = mask.unsqueeze(0)

    warped = perspective_image(
        mask, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients
    )

    return warped.squeeze(0) if add_channel else warped


@_register_kernel_internal(perspective, tv_tensors.Mask, tv_tensor_wrapper=False)
def _perspective_mask_dispatch(
    inpt: tv_tensors.Mask,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.Mask:
    """Unwrap a ``Mask``, run :func:`perspective_mask` and re-wrap the result like ``inpt``.

    Any extra keyword arguments are accepted and ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    warped = perspective_mask(
        plain_mask,
        startpoints=startpoints,
        endpoints=endpoints,
        fill=fill,
        coefficients=coefficients,
    )
    return tv_tensors.wrap(warped, like=inpt)


@_register_kernel_internal(perspective, tv_tensors.Video)
def perspective_video(
    video: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Video kernel for :func:`perspective`.

    The whole video tensor is handed to :func:`perspective_image`, which warps over the last two dimensions.
    """
    return perspective_image(
        video,
        startpoints,
        endpoints,
        interpolation=interpolation,
        fill=fill,
        coefficients=coefficients,
    )


def elastic(
    inpt: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.ElasticTransform` for details."""
    if torch.jit.is_scripting():
        # Under TorchScript, call the tensor kernel directly instead of dispatching through ``_get_kernel``.
        return elastic_image(inpt, displacement=displacement, interpolation=interpolation, fill=fill)

    _log_api_usage_once(elastic)

    # Eager mode: pick the kernel registered for the concrete input type and call it.
    return _get_kernel(elastic, type(inpt))(
        inpt, displacement=displacement, interpolation=interpolation, fill=fill
    )


1718
1719
1720
elastic_transform = elastic


@_register_kernel_internal(elastic, torch.Tensor)
@_register_kernel_internal(elastic, tv_tensors.Image)
def elastic_image(
    image: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Apply an elastic deformation to a tensor image.

    ``displacement`` must have shape ``(1, H, W, 2)`` matching the image's last two dimensions. It is added to
    an identity sampling grid, which is then applied with ``_apply_grid_transform``. Integer images are sampled
    through a ``float32`` grid. Half-precision CPU inputs are processed in ``float32`` and cast back to
    ``float16`` at the end.

    Raises:
        TypeError: if ``displacement`` is not a tensor.
        ValueError: if ``displacement`` does not have shape ``(1, H, W, 2)``.
    """
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    interpolation = _check_interpolation(interpolation)

    height, width = image.shape[-2:]
    device = image.device
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32

    # Workaround so that (cpu, float16) inputs are supported: compute in float32 and convert back afterwards.
    upcast_cpu_half = device.type == "cpu" and dtype == torch.float16
    if upcast_cpu_half:
        image = image.to(torch.float32)
        dtype = torch.float32

    # Known limitation: the displacement is cast to the grid dtype, so e.g. a uint8 image combined with a
    # float64 displacement is computed entirely in float32.
    expected_shape = (1, height, width, 2)
    if expected_shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    sampling_grid = _create_identity_grid((height, width), device=device, dtype=dtype)
    sampling_grid.add_(displacement.to(dtype=dtype, device=device))
    deformed = _apply_grid_transform(image, sampling_grid, interpolation.value, fill=fill)

    return deformed.to(torch.float16) if upcast_cpu_half else deformed


@_register_kernel_internal(elastic, PIL.Image.Image)
def _elastic_image_pil(
    image: PIL.Image.Image,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    """PIL kernel for :func:`elastic`.

    Converts the image to a tensor, deforms it with :func:`elastic_image` and converts it back using the
    original PIL mode.
    """
    deformed = elastic_image(pil_to_tensor(image), displacement, interpolation=interpolation, fill=fill)
    return to_pil_image(deformed, mode=image.mode)


1775
def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
1776
    sy, sx = size
1777
1778
    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
1779
1780
    base_grid[..., 0].copy_(x_grid)

1781
    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
1782
1783
1784
1785
1786
    base_grid[..., 1].copy_(y_grid)

    return base_grid


def elastic_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    displacement: torch.Tensor,
) -> torch.Tensor:
    """Apply an elastic deformation to bounding boxes.

    The inverse of the image sampling grid is *approximated* as ``identity_grid - displacement``; this is not an
    exact inverse. Each box is converted to XYXY. Its four corners are rounded up (floating point boxes only)
    to integer pixel indices and clamped into the canvas. They are looked up in the approximate inverse grid
    and mapped from normalized to pixel coordinates. The output box is the axis-aligned hull of those points,
    clamped to ``canvas_size``, converted back to ``format`` and reshaped to the input's shape.

    Raises:
        TypeError: if ``displacement`` is not a tensor.
        ValueError: if ``displacement`` does not have shape ``(1, canvas_size[0], canvas_size[1], 2)``.
    """
    expected_shape = (1, canvas_size[0], canvas_size[1], 2)
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")
    elif displacement.shape != expected_shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    if bounding_boxes.numel() == 0:
        return bounding_boxes

    device = bounding_boxes.device
    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
    bounding_boxes = (
        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
    # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
    # This is not an exact inverse of the grid
    inv_grid = id_grid.sub_(displacement)

    # Get the 4 corners of each bbox: (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    if points.is_floating_point():
        points = points.ceil_()
    index_xy = points.to(dtype=torch.long)
    # A corner on the right/bottom edge (x == W or y == H, legal for XYXY boxes) would index one past the grid.
    # Negative coordinates would silently wrap around. Clamp both into valid grid indices.
    index_x = index_xy[:, 0].clamp(0, canvas_size[1] - 1)
    index_y = index_xy[:, 1].clamp(0, canvas_size[0] - 1)

    # Transform points: map the normalized grid values back to pixel coordinates (inverse of the normalization
    # used by _create_identity_grid).
    t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    return convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)


@_register_kernel_internal(elastic, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _elastic_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, displacement: torch.Tensor, **kwargs
) -> tv_tensors.BoundingBoxes:
    """Unwrap ``BoundingBoxes``, run :func:`elastic_bounding_boxes` and re-wrap the result like ``inpt``.

    Any extra keyword arguments are accepted and ignored.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    deformed = elastic_bounding_boxes(
        plain_boxes, format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
    )
    return tv_tensors.wrap(deformed, like=inpt)


def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Apply an elastic deformation to a mask using nearest-neighbour interpolation.

    Masks with fewer than three dimensions get a temporary leading dimension for the image kernel, which is
    removed again before returning.
    """
    add_channel = mask.ndim < 3
    if add_channel:
        mask = mask.unsqueeze(0)

    deformed = elastic_image(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)

    return deformed.squeeze(0) if add_channel else deformed


@_register_kernel_internal(elastic, tv_tensors.Mask, tv_tensor_wrapper=False)
def _elastic_mask_dispatch(
    inpt: tv_tensors.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
) -> tv_tensors.Mask:
    """Unwrap a ``Mask``, run :func:`elastic_mask` and re-wrap the result like ``inpt``.

    Any extra keyword arguments are accepted and ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    deformed = elastic_mask(plain_mask, displacement=displacement, fill=fill)
    return tv_tensors.wrap(deformed, like=inpt)


@_register_kernel_internal(elastic, tv_tensors.Video)
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Video kernel for :func:`elastic`; the whole tensor is deformed by :func:`elastic_image`."""
    return elastic_image(
        video,
        displacement=displacement,
        interpolation=interpolation,
        fill=fill,
    )


def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.CenterCrop` for details."""
    # The docstring previously pointed at RandomCrop; center_crop is the functional form of CenterCrop.
    if torch.jit.is_scripting():
        # Under TorchScript, call the tensor kernel directly instead of dispatching through ``_get_kernel``.
        return center_crop_image(inpt, output_size=output_size)

    _log_api_usage_once(center_crop)

    # Eager mode: pick the kernel registered for the concrete input type.
    kernel = _get_kernel(center_crop, type(inpt))
    return kernel(inpt, output_size=output_size)


1902
1903
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    if isinstance(output_size, numbers.Number):
1904
1905
        s = int(output_size)
        return [s, s]
1906
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
1907
        return [output_size[0], output_size[0]]
1908
1909
    else:
        return list(output_size)
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928


def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
    return [
        (crop_width - image_width) // 2 if crop_width > image_width else 0,
        (crop_height - image_height) // 2 if crop_height > image_height else 0,
        (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
        (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
    ]


def _center_crop_compute_crop_anchor(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop_top, crop_left


@_register_kernel_internal(center_crop, torch.Tensor)
@_register_kernel_internal(center_crop, tv_tensors.Image)
def center_crop_image(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Crop the central ``output_size`` region of a tensor image.

    Empty inputs are only reshaped to the crop size. If the crop is larger than the image in either
    dimension, the image is first zero-padded (amounts from ``_center_crop_compute_padding``). When padding
    alone already yields the crop size, the padded image is returned directly.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    shape = image.shape
    if image.numel() == 0:
        return image.reshape(shape[:-2] + (crop_height, crop_width))
    image_height, image_width = shape[-2:]

    needs_padding = crop_height > image_height or crop_width > image_width
    if needs_padding:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)

        image_height, image_width = image.shape[-2:]
        if crop_width == image_width and crop_height == image_height:
            return image

    top, left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    bottom, right = top + crop_height, left + crop_width
    return image[..., top:bottom, left:right]


@_register_kernel_internal(center_crop, PIL.Image.Image)
def _center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
    """PIL kernel for :func:`center_crop`.

    Zero-pads the image first if the crop is larger than it. The padded image is returned directly when it
    already has the crop size; otherwise the centered window is cropped.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    image_height, image_width = _get_size_image_pil(image)

    if crop_height > image_height or crop_width > image_width:
        image = _pad_image_pil(
            image, _center_crop_compute_padding(crop_height, crop_width, image_height, image_width), fill=0
        )
        image_height, image_width = _get_size_image_pil(image)
        if (crop_height, crop_width) == (image_height, image_width):
            return image

    top, left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return _crop_image_pil(image, top, left, crop_height, crop_width)


def center_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Center-crop bounding boxes.

    Computes the centered crop window on ``canvas_size`` and delegates to :func:`crop_bounding_boxes`, whose
    ``(boxes, canvas_size)`` result is returned.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    top, left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
    return crop_bounding_boxes(bounding_boxes, format, top=top, left=left, height=crop_height, width=crop_width)


@_register_kernel_internal(center_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _center_crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, output_size: List[int]
) -> tv_tensors.BoundingBoxes:
    """Unwrap ``BoundingBoxes``, center-crop them and re-wrap with the updated canvas size."""
    cropped, new_canvas_size = center_crop_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
    )
    return tv_tensors.wrap(cropped, like=inpt, canvas_size=new_canvas_size)


@_register_kernel_internal(center_crop, tv_tensors.Mask)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Center-crop a mask with :func:`center_crop_image`.

    Masks with fewer than three dimensions get a temporary leading dimension, removed again before returning.
    """
    add_channel = mask.ndim < 3
    if add_channel:
        mask = mask.unsqueeze(0)

    cropped = center_crop_image(image=mask, output_size=output_size)

    return cropped.squeeze(0) if add_channel else cropped


@_register_kernel_internal(center_crop, tv_tensors.Video)
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Video kernel for :func:`center_crop`; the whole tensor is cropped by :func:`center_crop_image`."""
    return center_crop_image(video, output_size=output_size)


def resized_crop(
    inpt: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomResizedCrop` for details."""
    if torch.jit.is_scripting():
        # Under TorchScript, call the tensor kernel directly instead of dispatching through ``_get_kernel``.
        return resized_crop_image(
            inpt,
            top=top,
            left=left,
            height=height,
            width=width,
            size=size,
            interpolation=interpolation,
            antialias=antialias,
        )

    _log_api_usage_once(resized_crop)

    # Eager mode: pick the kernel registered for the concrete input type and call it.
    return _get_kernel(resized_crop, type(inpt))(
        inpt,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
        interpolation=interpolation,
        antialias=antialias,
    )


@_register_kernel_internal(resized_crop, torch.Tensor)
@_register_kernel_internal(resized_crop, tv_tensors.Image)
def resized_crop_image(
    image: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    """Crop the ``(top, left, height, width)`` window of a tensor image, then resize it to ``size``."""
    return resize_image(
        crop_image(image, top, left, height, width), size, interpolation=interpolation, antialias=antialias
    )


def _resized_crop_image_pil(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
    """Crop a PIL image to the ``(top, left, height, width)`` window, then resize it to ``size``."""
    cropped = _crop_image_pil(image, top, left, height, width)
    return _resize_image_pil(cropped, size, interpolation=interpolation)


@_register_kernel_internal(resized_crop, PIL.Image.Image)
def _resized_crop_image_pil_dispatch(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> PIL.Image.Image:
    """PIL kernel for :func:`resized_crop`.

    ``antialias`` is not forwarded. Passing ``antialias=False`` only emits a warning that anti-aliasing is
    always applied for PIL input.
    """
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return _resized_crop_image_pil(
        image, top=top, left=left, height=height, width=width, size=size, interpolation=interpolation
    )


def resized_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Crop bounding boxes to a window and rescale them to ``size``.

    Returns the transformed boxes together with the resulting canvas size from :func:`resize_bounding_boxes`.
    """
    cropped_boxes, cropped_canvas = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
    return resize_bounding_boxes(cropped_boxes, canvas_size=cropped_canvas, size=size)


@_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resized_crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.BoundingBoxes:
    """Unwrap ``BoundingBoxes``, crop-and-resize them, and re-wrap with the new canvas size.

    Any extra keyword arguments are accepted and ignored.
    """
    boxes, new_canvas_size = resized_crop_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
    )
    return tv_tensors.wrap(boxes, like=inpt, canvas_size=new_canvas_size)


def resized_crop_mask(
    mask: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> torch.Tensor:
    """Crop a mask to the ``(top, left, height, width)`` window, then resize it with :func:`resize_mask`."""
    return resize_mask(crop_mask(mask, top, left, height, width), size)


@_register_kernel_internal(resized_crop, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resized_crop_mask_dispatch(
    inpt: tv_tensors.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.Mask:
    """Unwrap a ``Mask``, run :func:`resized_crop_mask` and re-wrap the result like ``inpt``.

    Any extra keyword arguments are accepted and ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    result = resized_crop_mask(plain_mask, top=top, left=left, height=height, width=width, size=size)
    return tv_tensors.wrap(result, like=inpt)


@_register_kernel_internal(resized_crop, tv_tensors.Video)
def resized_crop_video(
    video: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    """Video kernel for :func:`resized_crop`; the whole tensor goes through :func:`resized_crop_image`."""
    return resized_crop_image(
        video,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
        interpolation=interpolation,
        antialias=antialias,
    )


def five_crop(
    inpt: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """See :class:`~torchvision.transforms.v2.FiveCrop` for details."""
    if torch.jit.is_scripting():
        # Under TorchScript, call the tensor kernel directly instead of dispatching through ``_get_kernel``.
        return five_crop_image(inpt, size=size)

    _log_api_usage_once(five_crop)

    # Eager mode: pick the kernel registered for the concrete input type and call it.
    return _get_kernel(five_crop, type(inpt))(inpt, size=size)


2176
2177
def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
2178
2179
        s = int(size)
        size = [s, s]
2180
    elif isinstance(size, (tuple, list)) and len(size) == 1:
2181
2182
        s = size[0]
        size = [s, s]
2183
2184
2185
2186
2187
2188
2189

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size


@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Image)
def five_crop_image(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the four corner crops and the center crop of a tensor image.

    The order is top-left, top-right, bottom-left, bottom-right, center.

    Raises:
        ValueError: if the crop is larger than the image in either dimension.
    """
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = image.shape[-2:]

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    # Offsets of the crops that touch the bottom and right borders.
    bottom_y = image_height - crop_height
    right_x = image_width - crop_width

    return (
        crop_image(image, 0, 0, crop_height, crop_width),
        crop_image(image, 0, right_x, crop_height, crop_width),
        crop_image(image, bottom_y, 0, crop_height, crop_width),
        crop_image(image, bottom_y, right_x, crop_height, crop_width),
        center_crop_image(image, [crop_height, crop_width]),
    )


@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image)
def _five_crop_image_pil(
    image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    """PIL kernel for :func:`five_crop`.

    Returns the top-left, top-right, bottom-left, bottom-right and center crops, in that order.

    Raises:
        ValueError: if the crop is larger than the image in either dimension.
    """
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = _get_size_image_pil(image)

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    # Offsets of the crops that touch the bottom and right borders.
    bottom_y = image_height - crop_height
    right_x = image_width - crop_width

    return (
        _crop_image_pil(image, 0, 0, crop_height, crop_width),
        _crop_image_pil(image, 0, right_x, crop_height, crop_width),
        _crop_image_pil(image, bottom_y, 0, crop_height, crop_width),
        _crop_image_pil(image, bottom_y, right_x, crop_height, crop_width),
        _center_crop_image_pil(image, [crop_height, crop_width]),
    )


@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Video)
def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Video kernel for :func:`five_crop`; the whole tensor is cropped by :func:`five_crop_image`."""
    return five_crop_image(video, size=size)


def ten_crop(
    inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """See :class:`~torchvision.transforms.v2.TenCrop` for details."""
    if torch.jit.is_scripting():
        # Under TorchScript, call the tensor kernel directly instead of dispatching through ``_get_kernel``.
        return ten_crop_image(inpt, size=size, vertical_flip=vertical_flip)

    _log_api_usage_once(ten_crop)

    # Eager mode: pick the kernel registered for the concrete input type and call it.
    return _get_kernel(ten_crop, type(inpt))(inpt, size=size, vertical_flip=vertical_flip)


@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Image)
def ten_crop_image(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Return the five crops of ``image`` followed by the five crops of a flipped copy.

    The copy is flipped vertically when ``vertical_flip`` is true and horizontally otherwise.
    """
    non_flipped = five_crop_image(image, size)
    mirrored = vertical_flip_image(image) if vertical_flip else horizontal_flip_image(image)
    flipped = five_crop_image(mirrored, size)
    return non_flipped + flipped


@_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image)
def _ten_crop_image_pil(
    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
]:
    """Ten-crop kernel for PIL images.

    Args:
        image: Input PIL image.
        size: Crop size, forwarded unchanged to ``_five_crop_image_pil``.
        vertical_flip: If ``True``, the second set of crops is taken from a vertically flipped copy
            of ``image``; otherwise from a horizontally flipped copy.

    Returns:
        Ten PIL images: the five crops of ``image`` followed by the five crops of its flipped copy.
    """
    non_flipped = _five_crop_image_pil(image, size)

    if vertical_flip:
        flipped_image = _vertical_flip_image_pil(image)
    else:
        flipped_image = _horizontal_flip_image_pil(image)

    flipped = _five_crop_image_pil(flipped_image, size)

    # Tuple concatenation: 5 original crops + 5 flipped crops.
    return non_flipped + flipped


@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Video)
def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Ten-crop kernel for :class:`~torchvision.tv_tensors.Video`.

    Videos reuse the tensor image kernel unchanged; this simply forwards to :func:`ten_crop_image`.
    NOTE(review): relies on the image kernel cropping/flipping only the trailing spatial dims so
    leading (e.g. time) dims pass through — confirm there.
    """
    return ten_crop_image(video, size, vertical_flip=vertical_flip)