import math
import numbers
import warnings
from typing import Any, List, Optional, Sequence, Tuple, Union

import PIL.Image
import torch
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad

from torchvision import tv_tensors
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric
from torchvision.transforms.functional import (
    _compute_resized_output_size as __compute_resized_output_size,
    _get_perspective_coeffs,
    _interpolation_modes_from_int,
    InterpolationMode,
    pil_modes_mapping,
    pil_to_tensor,
    to_pil_image,
)
from torchvision.utils import _log_api_usage_once

from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format
from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal


def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise ValueError(
            f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
            f"but got {interpolation}."
        )
    return interpolation


def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomHorizontalFlip` for details."""
    if torch.jit.is_scripting():
        return horizontal_flip_image(inpt)

    _log_api_usage_once(horizontal_flip)

    kernel = _get_kernel(horizontal_flip, type(inpt))
    return kernel(inpt)


@_register_kernel_internal(horizontal_flip, torch.Tensor)
@_register_kernel_internal(horizontal_flip, tv_tensors.Image)
def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
    return image.flip(-1)


@_register_kernel_internal(horizontal_flip, PIL.Image.Image)
def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    return _FP.hflip(image)


@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    return horizontal_flip_image(mask)


def horizontal_flip_bounding_boxes(
    bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_boxes.shape

    bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_()
    elif format == tv_tensors.BoundingBoxFormat.XYWH:
        bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_()
    else:  # format == tv_tensors.BoundingBoxFormat.CXCYWH:
        bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()

    return bounding_boxes.reshape(shape)


@_register_kernel_internal(horizontal_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _horizontal_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
    output = horizontal_flip_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(horizontal_flip, tv_tensors.Video)
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    return horizontal_flip_image(video)
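
# Illustrative sketch (not part of the torchvision API): flipping an XYXY box horizontally maps
# x -> canvas_width - x and swaps xmin/xmax, so on a 10x10 canvas (1, 2, 4, 5) becomes (6, 2, 9, 5):
#   >>> boxes = torch.tensor([[1.0, 2.0, 4.0, 5.0]])
#   >>> horizontal_flip_bounding_boxes(boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(10, 10))
#   tensor([[6., 2., 9., 5.]])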


def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details."""
    if torch.jit.is_scripting():
        return vertical_flip_image(inpt)

    _log_api_usage_once(vertical_flip)

    kernel = _get_kernel(vertical_flip, type(inpt))
    return kernel(inpt)


@_register_kernel_internal(vertical_flip, torch.Tensor)
@_register_kernel_internal(vertical_flip, tv_tensors.Image)
def vertical_flip_image(image: torch.Tensor) -> torch.Tensor:
    return image.flip(-2)


@_register_kernel_internal(vertical_flip, PIL.Image.Image)
def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    return _FP.vflip(image)


@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image(mask)


def vertical_flip_bounding_boxes(
    bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_boxes.shape

    bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_()
    elif format == tv_tensors.BoundingBoxFormat.XYWH:
        bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_()
    else:  # format == tv_tensors.BoundingBoxFormat.CXCYWH:
        bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()

    return bounding_boxes.reshape(shape)


@_register_kernel_internal(vertical_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _vertical_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
    output = vertical_flip_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(vertical_flip, tv_tensors.Video)
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image(video)


# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
hflip = horizontal_flip
vflip = vertical_flip


def _compute_resized_output_size(
    canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
    if isinstance(size, int):
        size = [size]
    elif max_size is not None and len(size) != 1:
        raise ValueError(
            "max_size should only be passed if size specifies the length of the smaller edge, "
            "i.e. size should be an int or a sequence of length 1 in torchscript mode."
        )
    return __compute_resized_output_size(canvas_size, size=size, max_size=max_size)


def resize(
    inpt: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.Resize` for details."""
    if torch.jit.is_scripting():
        return resize_image(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)

    _log_api_usage_once(resize)

    kernel = _get_kernel(resize, type(inpt))
    return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


@_register_kernel_internal(resize, torch.Tensor)
@_register_kernel_internal(resize, tv_tensors.Image)
def resize_image(
    image: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)
    antialias = False if antialias is None else antialias
    align_corners: Optional[bool] = None
    if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
        align_corners = False
    else:
        # The default of antialias is True from 0.17, so we don't warn or
        # error if other interpolation modes are used. This is documented.
        antialias = False

    shape = image.shape
    numel = image.numel()
    num_channels, old_height, old_width = shape[-3:]
    new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return image
    elif numel > 0:
        image = image.reshape(-1, num_channels, old_height, old_width)

        dtype = image.dtype
        acceptable_dtypes = [torch.float32, torch.float64]
        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
            # uint8 dtype can be included for cpu and cuda input if nearest mode
            acceptable_dtypes.append(torch.uint8)
        elif image.device.type == "cpu":
            # uint8 dtype support for bilinear and bicubic is limited to cpu and
            # according to our benchmarks, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
            if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (
                interpolation == InterpolationMode.BICUBIC
            ):
                acceptable_dtypes.append(torch.uint8)

        strides = image.stride()
        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
            # contiguous even though the input is un-ambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
            # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
            # we should be able to remove this hack.
            new_strides = list(strides)
            new_strides[0] = numel
            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)

        need_cast = dtype not in acceptable_dtypes
        if need_cast:
            image = image.to(dtype=torch.float32)

        image = interpolate(
            image,
            size=[new_height, new_width],
            mode=interpolation.value,
            align_corners=align_corners,
            antialias=antialias,
        )

        if need_cast:
            if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
                # This path is hit on non-AVX archs, or on GPU.
                image = image.clamp_(min=0, max=255)
            if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
                image = image.round_()
            image = image.to(dtype=dtype)

    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


def _resize_image_pil(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
) -> PIL.Image.Image:
    old_height, old_width = image.height, image.width
    new_height, new_width = _compute_resized_output_size(
        (old_height, old_width),
        size=size,  # type: ignore[arg-type]
        max_size=max_size,
    )

    interpolation = _check_interpolation(interpolation)

    if (new_height, new_width) == (old_height, old_width):
        return image

    return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation])


@_register_kernel_internal(resize, PIL.Image.Image)
def __resize_image_pil_dispatch(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> PIL.Image.Image:
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)


def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = resize_image(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(resize, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resize_mask_dispatch(
    inpt: tv_tensors.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> tv_tensors.Mask:
    output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size)
    return tv_tensors.wrap(output, like=inpt)


def resize_bounding_boxes(
    bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    old_height, old_width = canvas_size
    new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return bounding_boxes, canvas_size

    w_ratio = new_width / old_width
    h_ratio = new_height / old_height
    ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device)
    return (
        bounding_boxes.mul(ratios).to(bounding_boxes.dtype),
        (new_height, new_width),
    )


@_register_kernel_internal(resize, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resize_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = resize_bounding_boxes(
        inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
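
# Illustrative sketch (not part of the module): boxes are rescaled by the same width/height ratios as the
# image, regardless of box format. Resizing a (10, 20) canvas to (20, 40) doubles every coordinate:
#   >>> boxes = torch.tensor([[1.0, 2.0, 5.0, 6.0]])
#   >>> resize_bounding_boxes(boxes, canvas_size=(10, 20), size=[20, 40])
#   (tensor([[ 2.,  4., 10., 12.]]), (20, 40))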


@_register_kernel_internal(resize, tv_tensors.Video)
def resize_video(
    video: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


def affine(
    inpt: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomAffine` for details."""
    if torch.jit.is_scripting():
        return affine_image(
            inpt,
            angle=angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )

    _log_api_usage_once(affine)

    kernel = _get_kernel(affine, type(inpt))
    return kernel(
        inpt,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )


def _affine_parse_args(
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None:
        if not isinstance(center, (list, tuple)):
            raise TypeError("Argument center should be a sequence")
        else:
            center = [float(c) for c in center]

    return angle, translate, shear, center


def _get_inverse_affine_matrix(
    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
    # Helper method to compute inverse matrix for affine transformation

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                      , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # Cached results
    cos_sy = math.cos(sy)
    tan_sx = math.tan(sx)
    rot_minus_sy = rot - sy
    cx_plus_tx = cx + tx
    cy_plus_ty = cy + ty

    # Rotate Scale Shear (RSS) without scaling
    a = math.cos(rot_minus_sy) / cos_sy
    b = -(a * tan_sx + math.sin(rot))
    c = math.sin(rot_minus_sy) / cos_sy
    d = math.cos(rot) - c * tan_sx

    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
    else:
        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
        # Apply inverse of center translation: RSS * C^-1
        # and then apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy

    return matrix
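
# Sanity check (a sketch): for a pure 90 degree rotation about the origin (no translation, scale=1, no shear),
# `_get_inverse_affine_matrix([0.0, 0.0], 90.0, [0.0, 0.0], 1.0, [0.0, 0.0])` is approximately
# [0, 1, 0, -1, 0, 0], i.e. the inverse map sends an output point (x, y) back to the input point (y, -x),
# while `inverted=False` yields the forward map [0, -1, 0, 1, 0, 0], which sends (x, y) to (-y, x).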


def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired by the PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
    # Points are shifted due to affine matrix torch convention about
    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
    half_w = 0.5 * w
    half_h = 0.5 * h
    pts = torch.tensor(
        [
            [-half_w, -half_h, 1.0],
            [-half_w, half_h, 1.0],
            [half_w, half_h, 1.0],
            [half_w, -half_h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, max_vals = new_pts.aminmax(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    halfs = torch.tensor((half_w, half_h))
    min_vals.add_(halfs)
    max_vals.add_(halfs)

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    inv_tol = 1.0 / tol
    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
    size = cmax.sub_(cmin)
    return int(size[0]), int(size[1])  # w, h
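
# Illustrative note (a sketch): for a pure 90 degree rotation matrix, `_compute_affine_output_size(matrix, w, h)`
# returns (h, w), i.e. the expanded canvas that exactly contains the rotated w x h image
# (the returned order is (width, height)).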


def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
    input_shape = img.shape
    output_height, output_width = grid.shape[1], grid.shape[2]
    num_channels, input_height, input_width = input_shape[-3:]
    output_shape = input_shape[:-3] + (num_channels, output_height, output_width)

    if img.numel() == 0:
        return img.reshape(output_shape)

    img = img.reshape(-1, num_channels, input_height, input_width)
    squashed_batch_size = img.shape[0]

    # We are using context knowledge that grid should have float dtype
    fp = img.dtype == grid.dtype
    float_img = img if fp else img.to(grid.dtype)

    if squashed_batch_size > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(squashed_batch_size, -1, -1, -1)

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones(
            (squashed_batch_size, 1, input_height, input_width), dtype=float_img.dtype, device=float_img.device
        )
        float_img = torch.cat((float_img, mask), dim=1)

    float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
        mask = mask.expand_as(float_img)
        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
        fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
        if mode == "nearest":
            bool_mask = mask < 0.5
            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
        else:  # 'bilinear'
            # The following is mathematically equivalent to:
            # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)

    img = float_img.round_().to(img.dtype) if not fp else float_img

    return img.reshape(output_shape)
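
# Numeric check of the fill blend used in `_apply_grid_transform` above (a sketch with scalar values):
# with img = 10, fill = 2 and mask = 0.25,
#   img * mask + (1 - mask) * fill = 10 * 0.25 + 0.75 * 2 = 4.0
#   mask * (img - fill) + fill     = 0.25 * (10 - 2) + 2  = 4.0
# so the in-place `sub_().mul_().add_()` form computes the same blend with fewer temporaries.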


def _assert_grid_transform_inputs(
    image: torch.Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: _FillTypeJIT,
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        elif len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None:
        if isinstance(fill, (tuple, list)):
            length = len(fill)
            num_channels = image.shape[-3]
            if length > 1 and length != num_channels:
                raise ValueError(
                    "The number of elements in 'fill' cannot broadcast to match the number of "
                    f"channels of the image ({length} != {num_channels})"
                )
        elif not isinstance(fill, (int, float)):
            raise ValueError("Argument fill should be either int, float, tuple or list")

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _affine_grid(
    theta: torch.Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> torch.Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
    dtype = theta.dtype
    device = theta.device

    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


@_register_kernel_internal(affine, torch.Tensor)
@_register_kernel_internal(affine, tv_tensors.Image)
def affine_image(
    image: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)

    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    height, width = image.shape[-2:]

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [float(t) for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


@_register_kernel_internal(affine, PIL.Image.Image)
def _affine_image_pil(
    image: PIL.Image.Image,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    interpolation = _check_interpolation(interpolation)
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
    # It is visually better to estimate the center without the 0.5 offset;
    # otherwise an image rotated by 90 degrees is shifted vs the output of torch.rot90 or F_t.affine.
    if center is None:
        height, width = _get_size_image_pil(image)
        center = [width * 0.5, height * 0.5]

    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

    return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)


def _affine_bounding_boxes_with_expand(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    expand: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    if bounding_boxes.numel() == 0:
        return bounding_boxes, canvas_size

    original_shape = bounding_boxes.shape
    original_dtype = bounding_boxes.dtype
    bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
    dtype = bounding_boxes.dtype
    device = bounding_boxes.device
    bounding_boxes = (
        convert_bounding_box_format(
            bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
        )
    ).reshape(-1, 4)

    angle, translate, shear, center = _affine_parse_args(
        angle, translate, scale, shear, InterpolationMode.NEAREST, center
    )

    if center is None:
        height, width = canvas_size
        center = [width * 0.5, height * 0.5]

    affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
    transposed_affine_matrix = (
        torch.tensor(
            affine_vector,
            dtype=dtype,
            device=device,
        )
        .reshape(2, 3)
        .T
    )
    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
    # 2) Now let's transform the points using affine matrix
    transformed_points = torch.matmul(points, transposed_affine_matrix)
    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

    if expand:
        # Compute minimum point for transformed image frame:
        # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
        height, width = canvas_size
        points = torch.tensor(
            [
                [0.0, 0.0, 1.0],
                [0.0, float(height), 1.0],
                [float(width), float(height), 1.0],
                [float(width), 0.0, 1.0],
            ],
            dtype=dtype,
            device=device,
        )
        new_points = torch.matmul(points, transposed_affine_matrix)
        tr = torch.amin(new_points, dim=0, keepdim=True)
        # Translate bounding boxes
        out_bboxes.sub_(tr.repeat((1, 2)))
        # Estimate meta-data for image with inverted=True
        affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
        canvas_size = (new_height, new_width)

    out_bboxes = clamp_bounding_boxes(out_bboxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
    out_bboxes = convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)

    out_bboxes = out_bboxes.to(original_dtype)
    return out_bboxes, canvas_size


def affine_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    out_box, _ = _affine_bounding_boxes_with_expand(
        bounding_boxes,
        format=format,
        canvas_size=canvas_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
        expand=False,
    )
    return out_box


@_register_kernel_internal(affine, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _affine_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.BoundingBoxes:
    output = affine_bounding_boxes(
        inpt.as_subclass(torch.Tensor),
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
    )
    return tv_tensors.wrap(output, like=inpt)
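
# Illustrative sketch (not part of the module): boxes are transformed by mapping their four corners and
# taking the enclosing axis-aligned box, so a pure translation simply shifts the coordinates:
#   >>> boxes = torch.tensor([[1.0, 2.0, 4.0, 5.0]])
#   >>> affine_bounding_boxes(
#   ...     boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(10, 10),
#   ...     angle=0.0, translate=[2.0, 3.0], scale=1.0, shear=[0.0, 0.0],
#   ... )
#   tensor([[3., 5., 6., 8.]])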


def affine_mask(
    mask: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = affine_image(
        mask,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(affine, tv_tensors.Mask, tv_tensor_wrapper=False)
def _affine_mask_dispatch(
    inpt: tv_tensors.Mask,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.Mask:
    output = affine_mask(
        inpt.as_subclass(torch.Tensor),
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        fill=fill,
        center=center,
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(affine, tv_tensors.Video)
def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    return affine_image(
        video,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )


def rotate(
    inpt: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomRotation` for details."""
    if torch.jit.is_scripting():
        return rotate_image(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center)

    _log_api_usage_once(rotate)

    kernel = _get_kernel(rotate, type(inpt))
    return kernel(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


@_register_kernel_internal(rotate, torch.Tensor)
@_register_kernel_internal(rotate, tv_tensors.Image)
def rotate_image(
    image: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)

    input_height, input_width = image.shape[-2:]

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [input_width, input_height])]

    # Due to the current incoherence of the rotation angle direction between the affine and rotate
    # implementations, we need to pass -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    output_width, output_height = (
        _compute_affine_output_size(matrix, input_width, input_height) if expand else (input_width, input_height)
    )
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=input_width, h=input_height, ow=output_width, oh=output_height)
    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


@_register_kernel_internal(rotate, PIL.Image.Image)
def _rotate_image_pil(
    image: PIL.Image.Image,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    interpolation = _check_interpolation(interpolation)

    return _FP.rotate(
        image, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
    )


def rotate_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    return _affine_bounding_boxes_with_expand(
        bounding_boxes,
        format=format,
        canvas_size=canvas_size,
        angle=-angle,
        translate=[0.0, 0.0],
        scale=1.0,
        shear=[0.0, 0.0],
        center=center,
        expand=expand,
    )


@_register_kernel_internal(rotate, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _rotate_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = rotate_bounding_boxes(
        inpt.as_subclass(torch.Tensor),
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        angle=angle,
        expand=expand,
        center=center,
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


def rotate_mask(
    mask: torch.Tensor,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = rotate_image(
        mask,
        angle=angle,
        expand=expand,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(rotate, tv_tensors.Mask, tv_tensor_wrapper=False)
def _rotate_mask_dispatch(
    inpt: tv_tensors.Mask,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
    **kwargs,
) -> tv_tensors.Mask:
    output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(rotate, tv_tensors.Video)
def rotate_video(
    video: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    return rotate_image(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


def pad(
    inpt: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.Pad` for details."""
    if torch.jit.is_scripting():
        return pad_image(inpt, padding=padding, fill=fill, padding_mode=padding_mode)

    _log_api_usage_once(pad)

    kernel = _get_kernel(pad, type(inpt))
    return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode)


def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, (tuple, list)):
        if len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        elif len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        else:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )
    else:
        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")

    return [pad_left, pad_right, pad_top, pad_bottom]
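
# Quick reference (a sketch): `_parse_pad_padding` normalizes the PIL-style `padding` argument to the
# `[left, right, top, bottom]` order expected by `torch.nn.functional.pad`:
#   >>> _parse_pad_padding(2)             # same padding on every side
#   [2, 2, 2, 2]
#   >>> _parse_pad_padding([1, 2])        # [left/right, top/bottom]
#   [1, 1, 2, 2]
#   >>> _parse_pad_padding([1, 2, 3, 4])  # [left, top, right, bottom] -> [left, right, top, bottom]
#   [1, 3, 2, 4]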


@_register_kernel_internal(pad, torch.Tensor)
@_register_kernel_internal(pad, tv_tensors.Image)
def pad_image(
    image: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    # Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses
    # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad`
    # internally.
    torch_padding = _parse_pad_padding(padding)

    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
        raise ValueError(
            f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
            f"but got `'{padding_mode}'`."
        )

    if fill is None:
        fill = 0

    if isinstance(fill, (int, float)):
        return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
    elif len(fill) == 1:
        return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
    else:
        return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
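
# Illustrative sketch (not part of the module): `padding` follows the PIL order [left, top, right, bottom],
# so padding a (3, 4, 4) image with [1, 2, 3, 4] adds 1 + 3 columns and 2 + 4 rows:
#   >>> img = torch.zeros(3, 4, 4)
#   >>> pad_image(img, padding=[1, 2, 3, 4]).shape
#   torch.Size([3, 10, 8])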


def _pad_with_scalar_fill(
    image: torch.Tensor,
    torch_padding: List[int],
    fill: Union[int, float],
    padding_mode: str,
) -> torch.Tensor:
    shape = image.shape
    num_channels, height, width = shape[-3:]

    batch_size = 1
    for s in shape[:-3]:
        batch_size *= s

    image = image.reshape(batch_size, num_channels, height, width)

    if padding_mode == "edge":
        # Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus, we map
        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
        # name.
        padding_mode = "replicate"

    if padding_mode == "constant":
        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
    elif padding_mode in ("reflect", "replicate"):
        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
        # TODO: See https://github.com/pytorch/pytorch/issues/40763
        dtype = image.dtype
        if not image.is_floating_point():
            needs_cast = True
            image = image.to(torch.float32)
        else:
            needs_cast = False

        image = torch_pad(image, torch_padding, mode=padding_mode)

        if needs_cast:
            image = image.to(dtype)
    else:  # padding_mode == "symmetric"
        image = _pad_symmetric(image, torch_padding)

    new_height, new_width = image.shape[-2:]

    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


# TODO: This should be removed once torch_pad supports non-scalar padding values
def _pad_with_vector_fill(
    image: torch.Tensor,
    torch_padding: List[int],
    fill: List[float],
    padding_mode: str,
) -> torch.Tensor:
    if padding_mode != "constant":
        raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

    output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    left, right, top, bottom = torch_padding

    # We are creating the tensor in the autodetected dtype first and convert to the right one after to avoid an implicit
    # float -> int conversion. That happens for example for the valid input of a uint8 image with floating point fill
    # value.
    fill = torch.tensor(fill, device=image.device).to(dtype=image.dtype).reshape(-1, 1, 1)

    if top > 0:
        output[..., :top, :] = fill
    if left > 0:
        output[..., :, :left] = fill
    if bottom > 0:
        output[..., -bottom:, :] = fill
    if right > 0:
        output[..., :, -right:] = fill
    return output


_pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)


@_register_kernel_internal(pad, tv_tensors.Mask)
def pad_mask(
    mask: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    if fill is None:
        fill = 0

    if isinstance(fill, (tuple, list)):
        raise ValueError("Non-scalar fill value is not supported")

    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = pad_image(mask, padding=padding, fill=fill, padding_mode=padding_mode)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


def pad_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    padding: List[int],
    padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    if padding_mode not in ["constant"]:
        # TODO: add support of other padding modes
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

    left, right, top, bottom = _parse_pad_padding(padding)

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        pad = [left, top, left, top]
    else:
        pad = [left, top, 0, 0]
    bounding_boxes = bounding_boxes + torch.tensor(pad, dtype=bounding_boxes.dtype, device=bounding_boxes.device)

    height, width = canvas_size
    height += top + bottom
    width += left + right
    canvas_size = (height, width)

    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
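
# Illustrative sketch (not part of the module): padding shifts boxes by (left, top) and grows the canvas:
#   >>> boxes = torch.tensor([[1.0, 2.0, 4.0, 5.0]])
#   >>> pad_bounding_boxes(boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(10, 10), padding=[2])
#   (tensor([[3., 4., 6., 7.]]), (14, 14))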


@_register_kernel_internal(pad, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _pad_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = pad_bounding_boxes(
        inpt.as_subclass(torch.Tensor),
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        padding=padding,
        padding_mode=padding_mode,
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


@_register_kernel_internal(pad, tv_tensors.Video)
def pad_video(
    video: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    return pad_image(video, padding, fill=fill, padding_mode=padding_mode)


1316
def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
Nicolas Hug's avatar
Nicolas Hug committed
1317
    """[BETA] See :class:`~torchvision.transforms.v2.RandomCrop` for details."""
1318
    if torch.jit.is_scripting():
1319
        return crop_image(inpt, top=top, left=left, height=height, width=width)
1320
1321

    _log_api_usage_once(crop)
1322

1323
1324
    kernel = _get_kernel(crop, type(inpt))
    return kernel(inpt, top=top, left=left, height=height, width=width)
1325

1326
1327

@_register_kernel_internal(crop, torch.Tensor)
1328
@_register_kernel_internal(crop, tv_tensors.Image)
1329
def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
    h, w = image.shape[-2:]

    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        image = image[..., max(top, 0) : bottom, max(left, 0) : right]
        torch_padding = [
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    return image[..., top:bottom, left:right]


1347
1348
_crop_image_pil = _FP.crop
_register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil)


def crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:

    # Crop or implicit pad if left and/or top have negative values:
    if format == tv_tensors.BoundingBoxFormat.XYXY:
        sub = [left, top, left, top]
    else:
        sub = [left, top, 0, 0]

    bounding_boxes = bounding_boxes - torch.tensor(sub, dtype=bounding_boxes.dtype, device=bounding_boxes.device)
    canvas_size = (height, width)

    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size


@_register_kernel_internal(crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = crop_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


@_register_kernel_internal(crop, tv_tensors.Mask)
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = crop_image(mask, top, left, height, width)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(crop, tv_tensors.Video)
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    return crop_image(video, top, left, height, width)


def perspective(
    inpt: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomPerspective` for details."""
    if torch.jit.is_scripting():
        return perspective_image(
            inpt,
            startpoints=startpoints,
            endpoints=endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )

    _log_api_usage_once(perspective)

    kernel = _get_kernel(perspective, type(inpt))
    return kernel(
        inpt,
        startpoints=startpoints,
        endpoints=endpoints,
        interpolation=interpolation,
        fill=fill,
        coefficients=coefficients,
    )


def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394

    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
    shape = (1, oh * ow, 3)
    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))

    output_grid = output_grid1.div_(output_grid2).sub_(1.0)
    return output_grid.view(1, oh, ow, 2)


def _perspective_coefficients(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]],
) -> List[float]:
    if coefficients is not None:
        if startpoints is not None and endpoints is not None:
            raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.")
        elif len(coefficients) != 8:
            raise ValueError("Argument coefficients should have 8 float values")
        return coefficients
    elif startpoints is not None and endpoints is not None:
        return _get_perspective_coeffs(startpoints, endpoints)
    else:
        raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")


@_register_kernel_internal(perspective, torch.Tensor)
@_register_kernel_internal(perspective, tv_tensors.Image)
def perspective_image(
    image: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)

    _assert_grid_transform_inputs(
        image,
        matrix=None,
        interpolation=interpolation.value,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    oh, ow = image.shape[-2:]
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)
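# Usage sketch for perspective_image above (illustrative comment, not executed): warp a tensor
# image given four corner correspondences; the sampling grid is built from the Pillow-style
# coefficients and evaluated with the requested interpolation.
#
#     img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
#     out = perspective_image(
#         img,
#         startpoints=[[0, 0], [31, 0], [31, 31], [0, 31]],
#         endpoints=[[2, 1], [29, 3], [30, 30], [1, 28]],
#     )
#     # out has the same shape and dtype as img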


@_register_kernel_internal(perspective, PIL.Image.Image)
def _perspective_image_pil(
    image: PIL.Image.Image,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)
    return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)


def perspective_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    if bounding_boxes.numel() == 0:
        return bounding_boxes

    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
    bounding_boxes = (
        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32
    device = bounding_boxes.device

    # perspective_coeffs are computed as endpoint -> start point
    # We have to invert perspective_coeffs for bboxes:
    # (x, y) - end point and (x_out, y_out) - start point
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # and we would like to get:
    # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # and compute inv_coeffs in terms of coeffs

    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
    if denom == 0:
        raise RuntimeError(
            f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
            f"Denominator is zero, denom={denom}"
        )

    inv_coeffs = [
        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
        (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
        (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
        (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
    ]
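    # Note: inv_coeffs above are the entries of the inverse of the 3x3 homography
    # [[c0, c1, c2], [c3, c4, c5], [c6, c7, 1]] (c = perspective_coeffs), rescaled so that its
    # bottom-right entry equals 1 -- i.e. the adjugate divided by the cofactor `denom` -- which
    # is why `denom` must be non-zero for the transform to be invertible here.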

    theta1 = torch.tensor(
        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
        dtype=dtype,
        device=device,
    )

    theta2 = torch.tensor(
        [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
    )

    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
    # 2) Now let's transform the points using perspective matrices
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)

    numer_points = torch.matmul(points, theta1.T)
    denom_points = torch.matmul(points, theta2.T)
    transformed_points = numer_points.div_(denom_points)

    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)

    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    # out_bboxes should be of shape [N boxes, 4]

    return convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)


@_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _perspective_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.BoundingBoxes:
    output = perspective_bounding_boxes(
        inpt.as_subclass(torch.Tensor),
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        startpoints=startpoints,
        endpoints=endpoints,
        coefficients=coefficients,
    )
    return tv_tensors.wrap(output, like=inpt)


def perspective_mask(
    mask: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = perspective_image(
        mask, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(perspective, tv_tensors.Mask, tv_tensor_wrapper=False)
def _perspective_mask_dispatch(
    inpt: tv_tensors.Mask,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.Mask:
    output = perspective_mask(
        inpt.as_subclass(torch.Tensor),
        startpoints=startpoints,
        endpoints=endpoints,
        fill=fill,
        coefficients=coefficients,
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(perspective, tv_tensors.Video)
def perspective_video(
    video: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    return perspective_image(
        video, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
    )


def elastic(
    inpt: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.ElasticTransform` for details."""
    if torch.jit.is_scripting():
        return elastic_image(inpt, displacement=displacement, interpolation=interpolation, fill=fill)

    _log_api_usage_once(elastic)

    kernel = _get_kernel(elastic, type(inpt))
    return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill)


elastic_transform = elastic


@_register_kernel_internal(elastic, torch.Tensor)
@_register_kernel_internal(elastic, tv_tensors.Image)
def elastic_image(
    image: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    interpolation = _check_interpolation(interpolation)

    height, width = image.shape[-2:]
    device = image.device
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32

    # Patch: elastic transform should support (cpu,f16) input
    is_cpu_half = device.type == "cpu" and dtype == torch.float16
    if is_cpu_half:
        image = image.to(torch.float32)
        dtype = torch.float32

    # We are aware that if input image dtype is uint8 and displacement is float64 then
    # displacement will be cast to float32 and all computations will be done with float32
    # We can fix this later if needed

    expected_shape = (1, height, width, 2)
    if expected_shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    grid = _create_identity_grid((height, width), device=device, dtype=dtype).add_(
        displacement.to(dtype=dtype, device=device)
    )
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if is_cpu_half:
        output = output.to(torch.float16)

    return output
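# Usage sketch for elastic_image above (illustrative comment, not executed): the displacement is
# added to the normalized identity grid from _create_identity_grid and must have shape
# (1, H, W, 2).
#
#     img = torch.rand(3, 32, 32)
#     displacement = 0.05 * torch.randn(1, 32, 32, 2)
#     out = elastic_image(img, displacement)
#     same = elastic_image(img, torch.zeros(1, 32, 32, 2))  # zero displacement should be (close to) a no-op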


@_register_kernel_internal(elastic, PIL.Image.Image)
def _elastic_image_pil(
    image: PIL.Image.Image,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    t_img = pil_to_tensor(image)
    output = elastic_image(t_img, displacement, interpolation=interpolation, fill=fill)
    return to_pil_image(output, mode=image.mode)


def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    sy, sx = size
    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
    base_grid[..., 0].copy_(x_grid)

    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)

    return base_grid
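# Note on _create_identity_grid above: the grid holds (x, y) sampling locations at pixel centers
# in the normalized [-1, 1] range used by torch.nn.functional.grid_sample (align_corners=False
# convention), laid out as a (1, height, width, 2) tensor.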


def elastic_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    displacement: torch.Tensor,
) -> torch.Tensor:
    expected_shape = (1, canvas_size[0], canvas_size[1], 2)
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")
    elif displacement.shape != expected_shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    if bounding_boxes.numel() == 0:
        return bounding_boxes

    # TODO: add in docstring about approximation we are doing for grid inversion
    device = bounding_boxes.device
    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
    bounding_boxes = (
        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
    # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
    # This is not an exact inverse of the grid
    inv_grid = id_grid.sub_(displacement)

    # Get points from bboxes
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    if points.is_floating_point():
        points = points.ceil_()
    index_xy = points.to(dtype=torch.long)
    index_x, index_y = index_xy[:, 0], index_xy[:, 1]

    # Transform points:
    t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    return convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)


@_register_kernel_internal(elastic, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _elastic_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, displacement: torch.Tensor, **kwargs
) -> tv_tensors.BoundingBoxes:
    output = elastic_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
    )
    return tv_tensors.wrap(output, like=inpt)


def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = elastic_image(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(elastic, tv_tensors.Mask, tv_tensor_wrapper=False)
def _elastic_mask_dispatch(
    inpt: tv_tensors.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
) -> tv_tensors.Mask:
    output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(elastic, tv_tensors.Video)
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    return elastic_image(video, displacement, interpolation=interpolation, fill=fill)


def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomCrop` for details."""
    if torch.jit.is_scripting():
        return center_crop_image(inpt, output_size=output_size)

    _log_api_usage_once(center_crop)

    kernel = _get_kernel(center_crop, type(inpt))
    return kernel(inpt, output_size=output_size)


def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    if isinstance(output_size, numbers.Number):
        s = int(output_size)
        return [s, s]
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        return [output_size[0], output_size[0]]
    else:
        return list(output_size)


def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
    return [
        (crop_width - image_width) // 2 if crop_width > image_width else 0,
        (crop_height - image_height) // 2 if crop_height > image_height else 0,
        (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
        (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
    ]


def _center_crop_compute_crop_anchor(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop_top, crop_left
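# Worked example for the two helpers above (illustrative comment): for a 10x10 image and a
# 4x6 (height x width) crop, _center_crop_compute_crop_anchor gives
# crop_top = round((10 - 4) / 2) = 3 and crop_left = round((10 - 6) / 2) = 2, i.e. the crop
# window is image[..., 3:7, 2:8]; _center_crop_compute_padding only kicks in when the requested
# crop is larger than the image.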


@_register_kernel_internal(center_crop, torch.Tensor)
@_register_kernel_internal(center_crop, tv_tensors.Image)
def center_crop_image(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    shape = image.shape
    if image.numel() == 0:
        return image.reshape(shape[:-2] + (crop_height, crop_width))
    image_height, image_width = shape[-2:]

    if crop_height > image_height or crop_width > image_width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)

        image_height, image_width = image.shape[-2:]
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
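# Usage sketch for center_crop_image above (illustrative comment, not executed): when the
# requested size exceeds the input, the image is first zero-padded so the output still has the
# requested size.
#
#     img = torch.rand(3, 8, 8)
#     out_small = center_crop_image(img, [4, 4])    # out_small.shape == (3, 4, 4)
#     out_large = center_crop_image(img, [12, 12])  # out_large.shape == (3, 12, 12), border is zero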


@_register_kernel_internal(center_crop, PIL.Image.Image)
def _center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    image_height, image_width = _get_size_image_pil(image)

    if crop_height > image_height or crop_width > image_width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = _pad_image_pil(image, padding_ltrb, fill=0)

        image_height, image_width = _get_size_image_pil(image)
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)


def center_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
    return crop_bounding_boxes(
        bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width
    )


@_register_kernel_internal(center_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _center_crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, output_size: List[int]
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = center_crop_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


@_register_kernel_internal(center_crop, tv_tensors.Mask)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = center_crop_image(image=mask, output_size=output_size)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(center_crop, tv_tensors.Video)
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    return center_crop_image(video, output_size)


def resized_crop(
    inpt: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomResizedCrop` for details."""
    if torch.jit.is_scripting():
        return resized_crop_image(
            inpt,
            top=top,
            left=left,
            height=height,
            width=width,
            size=size,
            interpolation=interpolation,
            antialias=antialias,
        )

    _log_api_usage_once(resized_crop)

    kernel = _get_kernel(resized_crop, type(inpt))
    return kernel(
        inpt,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
        interpolation=interpolation,
        antialias=antialias,
    )


@_register_kernel_internal(resized_crop, torch.Tensor)
@_register_kernel_internal(resized_crop, tv_tensors.Image)
def resized_crop_image(
    image: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    image = crop_image(image, top, left, height, width)
    return resize_image(image, size, interpolation=interpolation, antialias=antialias)
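# Usage sketch for resized_crop_image above (illustrative comment, not executed): the kernel is
# simply crop_image followed by resize_image, so it is equivalent to composing the two calls.
#
#     img = torch.rand(3, 64, 64)
#     out = resized_crop_image(img, top=8, left=8, height=32, width=32, size=[16, 16])
#     # out.shape == (3, 16, 16)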


def _resized_crop_image_pil(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
    image = _crop_image_pil(image, top, left, height, width)
    return _resize_image_pil(image, size, interpolation=interpolation)


@_register_kernel_internal(resized_crop, PIL.Image.Image)
def _resized_crop_image_pil_dispatch(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> PIL.Image.Image:
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return _resized_crop_image_pil(
        image,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
        interpolation=interpolation,
    )


def resized_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
    return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size)


@_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resized_crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = resized_crop_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


def resized_crop_mask(
    mask: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> torch.Tensor:
    mask = crop_mask(mask, top, left, height, width)
    return resize_mask(mask, size)


@_register_kernel_internal(resized_crop, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resized_crop_mask_dispatch(
    inpt: tv_tensors.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.Mask:
    output = resized_crop_mask(
        inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(resized_crop, tv_tensors.Video)
def resized_crop_video(
    video: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    return resized_crop_image(
        video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
    )


def five_crop(
    inpt: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """[BETA] See :class:`~torchvision.transforms.v2.FiveCrop` for details."""
    if torch.jit.is_scripting():
        return five_crop_image(inpt, size=size)

    _log_api_usage_once(five_crop)

    kernel = _get_kernel(five_crop, type(inpt))
    return kernel(inpt, size=size)


def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
        s = int(size)
        size = [s, s]
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        s = size[0]
        size = [s, s]

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size


@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Image)
def five_crop_image(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = image.shape[-2:]

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    tl = crop_image(image, 0, 0, crop_height, crop_width)
    tr = crop_image(image, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop_image(image, image_height - crop_height, 0, crop_height, crop_width)
    br = crop_image(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop_image(image, [crop_height, crop_width])

    return tl, tr, bl, br, center
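# Usage sketch for five_crop_image above (illustrative comment, not executed): the five crops
# are returned in the order top-left, top-right, bottom-left, bottom-right, center.
#
#     img = torch.rand(3, 32, 32)
#     tl, tr, bl, br, center = five_crop_image(img, [16, 16])
#     # each crop has shape (3, 16, 16)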


@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image)
def _five_crop_image_pil(
    image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = _get_size_image_pil(image)

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    tl = _crop_image_pil(image, 0, 0, crop_height, crop_width)
    tr = _crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
    bl = _crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width)
    br = _crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = _center_crop_image_pil(image, [crop_height, crop_width])

    return tl, tr, bl, br, center


@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Video)
def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    return five_crop_image(video, size)


def ten_crop(
    inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """[BETA] See :class:`~torchvision.transforms.v2.TenCrop` for details."""
    if torch.jit.is_scripting():
        return ten_crop_image(inpt, size=size, vertical_flip=vertical_flip)

    _log_api_usage_once(ten_crop)

    kernel = _get_kernel(ten_crop, type(inpt))
    return kernel(inpt, size=size, vertical_flip=vertical_flip)


@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Image)
def ten_crop_image(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    non_flipped = five_crop_image(image, size)

    if vertical_flip:
        image = vertical_flip_image(image)
    else:
        image = horizontal_flip_image(image)

    flipped = five_crop_image(image, size)

    return non_flipped + flipped
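# Usage sketch for ten_crop_image above (illustrative comment, not executed): the result is the
# five crops of the input followed by the five crops of its horizontally (or, with
# vertical_flip=True, vertically) flipped version.
#
#     crops = ten_crop_image(torch.rand(3, 32, 32), [16, 16])
#     # len(crops) == 10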


@_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image)
def _ten_crop_image_pil(
    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
]:
    non_flipped = _five_crop_image_pil(image, size)

    if vertical_flip:
        image = _vertical_flip_image_pil(image)
    else:
        image = _horizontal_flip_image_pil(image)

    flipped = _five_crop_image_pil(image, size)

    return non_flipped + flipped


@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Video)
def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    return ten_crop_image(video, size, vertical_flip=vertical_flip)