_geometry.py 84.9 KB
Newer Older
1
import math
2
import numbers
3
import warnings
4
from typing import Any, List, Optional, Sequence, Tuple, Union
5
6
7

import PIL.Image
import torch
8
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
9

10
from torchvision import tv_tensors
11
12
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric
13
from torchvision.transforms.functional import (
14
    _check_antialias,
15
    _compute_resized_output_size as __compute_resized_output_size,
16
    _get_perspective_coeffs,
17
    _interpolation_modes_from_int,
18
    InterpolationMode,
19
    pil_modes_mapping,
20
21
    pil_to_tensor,
    to_pil_image,
22
)
23

24
25
from torchvision.utils import _log_api_usage_once

Nicolas Hug's avatar
Nicolas Hug committed
26
from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format
27

28
from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
29

30

31
32
33
34
35
36
37
38
39
40
41
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    """Normalize ``interpolation`` to an :class:`InterpolationMode`.

    Args:
        interpolation: Either an ``InterpolationMode`` member, which is returned unchanged, or an ``int``,
            which is translated through ``_interpolation_modes_from_int``.

    Returns:
        The corresponding ``InterpolationMode``.

    Raises:
        ValueError: If ``interpolation`` is neither an ``int`` nor an ``InterpolationMode``.
    """
    if isinstance(interpolation, int):
        return _interpolation_modes_from_int(interpolation)

    if not isinstance(interpolation, InterpolationMode):
        raise ValueError(
            f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
            f"but got {interpolation}."
        )

    return interpolation


42
def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomHorizontalFlip` for details."""
    # Under TorchScript the kernel registry is bypassed and the tensor kernel is called directly.
    if torch.jit.is_scripting():
        return horizontal_flip_image(inpt)

    _log_api_usage_once(horizontal_flip)

    # Look up the kernel registered for the concrete input type and delegate to it.
    return _get_kernel(horizontal_flip, type(inpt))(inpt)
51
52


53
@_register_kernel_internal(horizontal_flip, torch.Tensor)
@_register_kernel_internal(horizontal_flip, tv_tensors.Image)
def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
    """Return ``image`` reversed along its last dimension.

    Registered as the ``horizontal_flip`` kernel for plain tensors and ``tv_tensors.Image``.
    """
    # NOTE(review): assumes the last dimension is the width axis (``[..., H, W]`` layout) — confirm with callers.
    return image.flip(-1)


59
@_register_kernel_internal(horizontal_flip, PIL.Image.Image)
def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """``horizontal_flip`` kernel for PIL images; delegates to ``_FP.hflip``."""
    return _FP.hflip(image)
62
63


64
@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    """``horizontal_flip`` kernel for masks; reuses the image kernel unchanged."""
    return horizontal_flip_image(mask)
67
68


69
def horizontal_flip_bounding_boxes(
    bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    """Mirror bounding boxes horizontally inside a canvas.

    Args:
        bounding_boxes: Tensor whose trailing dimension holds the 4 box coordinates.
        format: Coordinate layout of the boxes (``XYXY``, ``XYWH``; anything else is handled as ``CXCYWH``).
        canvas_size: ``(height, width)`` of the canvas; only the width (index 1) is used.

    Returns:
        A new tensor with the input's shape. The input tensor itself is not modified.
    """
    original_shape = bounding_boxes.shape

    # Clone first so the in-place arithmetic below never writes into the caller's tensor.
    boxes = bounding_boxes.clone().reshape(-1, 4)
    canvas_width = canvas_size[1]

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        # The mirrored x1 is ``W - x2`` and the mirrored x2 is ``W - x1``, hence the swapped column order.
        boxes[:, [2, 0]] = boxes[:, [0, 2]].sub_(canvas_width).neg_()
    elif format == tv_tensors.BoundingBoxFormat.XYWH:
        # New left edge: ``W - (x + w)``.
        boxes[:, 0].add_(boxes[:, 2]).sub_(canvas_width).neg_()
    else:  # format == tv_tensors.BoundingBoxFormat.CXCYWH:
        # New center: ``W - cx``.
        boxes[:, 0].sub_(canvas_width).neg_()

    return boxes.reshape(original_shape)
84
85


86
87
@_register_kernel_internal(horizontal_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _horizontal_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
    """``horizontal_flip`` kernel for ``BoundingBoxes``: unwrap, flip using the box metadata, re-wrap."""
    plain_boxes = inpt.as_subclass(torch.Tensor)
    flipped = horizontal_flip_bounding_boxes(plain_boxes, format=inpt.format, canvas_size=inpt.canvas_size)
    return tv_tensors.wrap(flipped, like=inpt)
92
93


94
@_register_kernel_internal(horizontal_flip, tv_tensors.Video)
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    """``horizontal_flip`` kernel for videos; reuses the image kernel unchanged."""
    return horizontal_flip_image(video)
97
98


99
def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details."""
    # Under TorchScript the kernel registry is bypassed and the tensor kernel is called directly.
    if torch.jit.is_scripting():
        return vertical_flip_image(inpt)

    _log_api_usage_once(vertical_flip)

    # Look up the kernel registered for the concrete input type and delegate to it.
    return _get_kernel(vertical_flip, type(inpt))(inpt)
108
109


110
@_register_kernel_internal(vertical_flip, torch.Tensor)
@_register_kernel_internal(vertical_flip, tv_tensors.Image)
def vertical_flip_image(image: torch.Tensor) -> torch.Tensor:
    """Return ``image`` reversed along its second-to-last dimension.

    Registered as the ``vertical_flip`` kernel for plain tensors and ``tv_tensors.Image``.
    """
    # NOTE(review): assumes the second-to-last dimension is the height axis (``[..., H, W]``) — confirm.
    return image.flip(-2)


116
@_register_kernel_internal(vertical_flip, PIL.Image.Image)
def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """``vertical_flip`` kernel for PIL images; delegates to ``_FP.vflip``.

    The annotations name the ``PIL.Image.Image`` class (not the ``PIL.Image`` module), matching the type this
    kernel is registered for and the signature of ``_horizontal_flip_image_pil``.
    """
    return _FP.vflip(image)
119
120


121
@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    """``vertical_flip`` kernel for masks; reuses the image kernel unchanged."""
    return vertical_flip_image(mask)
124
125


126
def vertical_flip_bounding_boxes(
    bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    """Mirror bounding boxes vertically inside a canvas.

    Args:
        bounding_boxes: Tensor whose trailing dimension holds the 4 box coordinates.
        format: Coordinate layout of the boxes (``XYXY``, ``XYWH``; anything else is handled as ``CXCYWH``).
        canvas_size: ``(height, width)`` of the canvas; only the height (index 0) is used.

    Returns:
        A new tensor with the input's shape. The input tensor itself is not modified.
    """
    original_shape = bounding_boxes.shape

    # Clone first so the in-place arithmetic below never writes into the caller's tensor.
    boxes = bounding_boxes.clone().reshape(-1, 4)
    canvas_height = canvas_size[0]

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        # The mirrored y1 is ``H - y2`` and the mirrored y2 is ``H - y1``, hence the swapped column order.
        boxes[:, [1, 3]] = boxes[:, [3, 1]].sub_(canvas_height).neg_()
    elif format == tv_tensors.BoundingBoxFormat.XYWH:
        # New top edge: ``H - (y + h)``.
        boxes[:, 1].add_(boxes[:, 3]).sub_(canvas_height).neg_()
    else:  # format == tv_tensors.BoundingBoxFormat.CXCYWH:
        # New center: ``H - cy``.
        boxes[:, 1].sub_(canvas_height).neg_()

    return boxes.reshape(original_shape)
141
142


143
144
@_register_kernel_internal(vertical_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _vertical_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
    """``vertical_flip`` kernel for ``BoundingBoxes``: unwrap, flip using the box metadata, re-wrap."""
    plain_boxes = inpt.as_subclass(torch.Tensor)
    flipped = vertical_flip_bounding_boxes(plain_boxes, format=inpt.format, canvas_size=inpt.canvas_size)
    return tv_tensors.wrap(flipped, like=inpt)
149

150

151
@_register_kernel_internal(vertical_flip, tv_tensors.Video)
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
    """``vertical_flip`` kernel for videos; reuses the image kernel unchanged."""
    return vertical_flip_image(video)
154
155


156
157
158
159
160
161
# We changed the names to align them with the transforms, e.g. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
# prevalent and well understood. Thus, we simply alias them without deprecating the old names: each alias is the
# very same function object as the long-named dispatcher.
hflip = horizontal_flip
vflip = vertical_flip


162
def _compute_resized_output_size(
    canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
    """Validate ``size`` / ``max_size`` and delegate to the shared output-size helper.

    Args:
        canvas_size: Current size, forwarded unchanged to the helper.
        size: Requested size. A bare ``int`` is accepted and wrapped into a one-element list.
        max_size: Optional limit; only allowed when ``size`` is an ``int`` or has exactly one element.

    Returns:
        Whatever the imported ``_compute_resized_output_size`` helper returns for these arguments.

    Raises:
        ValueError: If ``max_size`` is given together with a ``size`` sequence whose length is not 1.
    """
    if isinstance(size, int):
        size = [size]
    else:
        if max_size is not None and len(size) != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )

    return __compute_resized_output_size(canvas_size, size=size, max_size=max_size)
173
174


175
def resize(
    inpt: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.Resize` for details."""
    # Under TorchScript the kernel registry is bypassed and the tensor kernel is called directly.
    if torch.jit.is_scripting():
        return resize_image(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)

    _log_api_usage_once(resize)

    # Delegate to the kernel registered for the concrete input type.
    selected_kernel = _get_kernel(resize, type(inpt))
    return selected_kernel(
        inpt,
        size=size,
        interpolation=interpolation,
        max_size=max_size,
        antialias=antialias,
    )
190
191


192
@_register_kernel_internal(resize, torch.Tensor)
@_register_kernel_internal(resize, tv_tensors.Image)
def resize_image(
    image: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Resize a tensor image with ``torch.nn.functional.interpolate``.

    Args:
        image: Tensor whose last three dimensions are unpacked as ``(channels, height, width)``; any leading
            dimensions are folded into a batch dimension for the interpolation and restored afterwards.
        size: Requested size, forwarded to ``_compute_resized_output_size``.
        interpolation: ``InterpolationMode`` or integer constant, normalized by ``_check_interpolation``.
        max_size: Forwarded to ``_compute_resized_output_size``.
        antialias: Processed by ``_check_antialias``; forced to ``False`` for modes other than bilinear/bicubic.

    Returns:
        The input object itself if the computed size equals the current size, otherwise a tensor with the
        original leading dimensions and dtype and the new height and width.
    """
    interpolation = _check_interpolation(interpolation)
    antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation)
    # After `_check_antialias` the string sentinel must be gone; the assert also narrows the type.
    assert not isinstance(antialias, str)
    antialias = False if antialias is None else antialias
    align_corners: Optional[bool] = None
    if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
        align_corners = False
    else:
        # The default of antialias should be True from 0.17, so we don't warn or
        # error if other interpolation modes are used. This is documented.
        antialias = False

    shape = image.shape
    numel = image.numel()
    num_channels, old_height, old_width = shape[-3:]
    new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        # Nothing to do: hand back the very same tensor object.
        return image
    elif numel > 0:
        # Fold all leading dimensions into a single batch dimension for `interpolate`.
        image = image.reshape(-1, num_channels, old_height, old_width)

        dtype = image.dtype
        # Dtypes that can be passed to `interpolate` directly, without a round trip through float32.
        acceptable_dtypes = [torch.float32, torch.float64]
        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
            # uint8 dtype can be included for cpu and cuda input if nearest mode
            acceptable_dtypes.append(torch.uint8)
        elif image.device.type == "cpu":
            # uint8 dtype support for bilinear and bicubic is limited to cpu and
            # according to our benchmarks, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
            if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (
                interpolation == InterpolationMode.BICUBIC
            ):
                acceptable_dtypes.append(torch.uint8)

        strides = image.stride()
        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
            # contiguous even though the input is un-ambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
            # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
            # we should be able to remove this hack.
            new_strides = list(strides)
            new_strides[0] = numel
            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)

        need_cast = dtype not in acceptable_dtypes
        if need_cast:
            image = image.to(dtype=torch.float32)

        image = interpolate(
            image,
            size=[new_height, new_width],
            mode=interpolation.value,
            align_corners=align_corners,
            antialias=antialias,
        )

        if need_cast:
            if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
                # This path is hit on non-AVX archs, or on GPU.
                # The result is clamped to the uint8 range before being cast back.
                image = image.clamp_(min=0, max=255)
            if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
                # Round before the integer cast.
                image = image.round_()
            image = image.to(dtype=dtype)

    # Restore the original leading dimensions; this line is also reached when `numel == 0`.
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
271
272


273
def _resize_image_pil(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
) -> PIL.Image.Image:
    """Resize a PIL image.

    Args:
        image: Image to resize.
        size: Requested size, forwarded to ``_compute_resized_output_size``.
        interpolation: ``InterpolationMode`` or integer constant, normalized by ``_check_interpolation``.
        max_size: Forwarded to ``_compute_resized_output_size``.

    Returns:
        ``image`` itself when the computed size equals the current size, otherwise the result of
        ``image.resize``.
    """
    current_height, current_width = image.height, image.width
    target_height, target_width = _compute_resized_output_size(
        (current_height, current_width),
        size=size,  # type: ignore[arg-type]
        max_size=max_size,
    )

    # Validated before the no-op shortcut, so an invalid mode is reported even when no resize is needed.
    interpolation = _check_interpolation(interpolation)

    if (target_height, target_width) == (current_height, current_width):
        return image

    # PIL takes the target size as (width, height).
    resample = pil_modes_mapping[interpolation]
    return image.resize((target_width, target_height), resample=resample)
292
293


294
@_register_kernel_internal(resize, PIL.Image.Image)
def __resize_image_pil_dispatch(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> PIL.Image.Image:
    """``resize`` kernel for PIL images.

    Accepts ``antialias`` for signature parity with the tensor kernels but does not forward it; a warning is
    emitted only when the caller explicitly passes ``antialias=False``.
    """
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)
305
306


307
308
309
def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
    """Resize a mask with nearest-neighbour interpolation.

    Masks with fewer than 3 dimensions get a temporary leading dimension so that ``resize_image`` can be
    reused; that dimension is removed again before returning.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    resized = resize_image(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)

    if needs_squeeze:
        resized = resized.squeeze(0)
    return resized
320
321


322
@_register_kernel_internal(resize, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resize_mask_dispatch(
    inpt: tv_tensors.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> tv_tensors.Mask:
    """``resize`` kernel for ``Mask``: unwrap, resize, re-wrap.

    Extra keyword arguments (e.g. those meant for the image kernel) are accepted and ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    resized = resize_mask(plain_mask, size, max_size=max_size)
    return tv_tensors.wrap(resized, like=inpt)
328
329


330
def resize_bounding_boxes(
    bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Rescale bounding boxes to follow a resize of their canvas.

    Args:
        bounding_boxes: Tensor whose trailing dimension holds 4 coordinates; elements 0 and 2 are scaled by
            the width ratio, elements 1 and 3 by the height ratio.
        canvas_size: Current ``(height, width)`` of the canvas.
        size: Requested size, forwarded to ``_compute_resized_output_size``.
        max_size: Forwarded to ``_compute_resized_output_size``.

    Returns:
        ``(boxes, (new_height, new_width))``. If the size does not change, the input tensor and
        ``canvas_size`` are returned as-is.
    """
    current_height, current_width = canvas_size
    target_height, target_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size)

    if (target_height, target_width) == (current_height, current_width):
        return bounding_boxes, canvas_size

    scale_x = target_width / current_width
    scale_y = target_height / current_height
    scale_factors = torch.tensor([scale_x, scale_y, scale_x, scale_y], device=bounding_boxes.device)

    # The multiplication may promote the dtype, so cast back to the dtype of the input.
    scaled_boxes = bounding_boxes.mul(scale_factors).to(bounding_boxes.dtype)
    return scaled_boxes, (target_height, target_width)
346
347


348
@_register_kernel_internal(resize, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resize_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> tv_tensors.BoundingBoxes:
    """``resize`` kernel for ``BoundingBoxes``: unwrap, rescale, re-wrap with the new canvas size.

    Extra keyword arguments (e.g. those meant for the image kernel) are accepted and ignored.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    resized_boxes, new_canvas_size = resize_bounding_boxes(plain_boxes, inpt.canvas_size, size, max_size=max_size)
    return tv_tensors.wrap(resized_boxes, like=inpt, canvas_size=new_canvas_size)
356
357


358
@_register_kernel_internal(resize, tv_tensors.Video)
def resize_video(
    video: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """``resize`` kernel for videos; forwards all arguments to ``resize_image`` unchanged."""
    return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
367
368


369
def affine(
    inpt: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomAffine` for details."""
    # Under TorchScript the kernel registry is bypassed and the tensor kernel is called directly.
    if torch.jit.is_scripting():
        return affine_image(
            inpt,
            angle=angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )

    _log_api_usage_once(affine)

    # Delegate to the kernel registered for the concrete input type.
    selected_kernel = _get_kernel(affine, type(inpt))
    return selected_kernel(
        inpt,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )
405
406


407
def _affine_parse_args(
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
    """Validate and normalize the arguments shared by the affine kernels.

    Args:
        angle: Rotation angle; an ``int`` is converted to ``float``.
        translate: List or tuple of exactly two values; a tuple is converted to a list.
        scale: Must be strictly positive.
        shear: A single number (expanded to ``[shear, 0.0]``), a one-element sequence (the value is
            duplicated), or a two-element sequence.
        interpolation: Only type-checked here; it is not part of the return value.
        center: Optional list or tuple; its entries are converted to ``float``.

    Returns:
        The normalized ``(angle, translate, shear, center)``.

    Raises:
        TypeError: If ``angle``, ``translate``, ``shear``, ``interpolation`` or ``center`` has an
            unsupported type.
        ValueError: If ``translate`` does not have two values, ``scale`` is not positive, or ``shear`` does
            not end up with two values.
    """
    # --- validation -------------------------------------------------------------------------------------
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    # --- normalization ----------------------------------------------------------------------------------
    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        # A scalar shear becomes ``[shear, 0.0]``.
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        # A one-element sequence, unlike a scalar, has its value used for both entries.
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None:
        if not isinstance(center, (list, tuple)):
            raise TypeError("Argument center should be a sequence")
        else:
            center = [float(c) for c in center]

    return angle, translate, shear, center


460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
def _get_inverse_affine_matrix(
    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
    # Helper method to compute inverse matrix for affine transformation

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                      , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # Cached results
    cos_sy = math.cos(sy)
    tan_sx = math.tan(sx)
    rot_minus_sy = rot - sy
    cx_plus_tx = cx + tx
    cy_plus_ty = cy + ty

    # Rotate Scale Shear (RSS) without scaling
    a = math.cos(rot_minus_sy) / cos_sy
    b = -(a * tan_sx + math.sin(rot))
    c = math.sin(rot_minus_sy) / cos_sy
    d = math.cos(rot) - c * tan_sx

    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
    else:
        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
        # Apply inverse of center translation: RSS * C^-1
        # and then apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy

    return matrix


def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired of PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
    # Points are shifted due to affine matrix torch convention about
    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
    half_w = 0.5 * w
    half_h = 0.5 * h
    pts = torch.tensor(
        [
            [-half_w, -half_h, 1.0],
            [-half_w, half_h, 1.0],
            [half_w, half_h, 1.0],
            [half_w, -half_h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, max_vals = new_pts.aminmax(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    halfs = torch.tensor((half_w, half_h))
    min_vals.add_(halfs)
    max_vals.add_(halfs)

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    inv_tol = 1.0 / tol
    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
    size = cmax.sub_(cmin)
    return int(size[0]), int(size[1])  # w, h


556
def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
    """Sample ``img`` at the locations in ``grid`` and optionally paint uncovered pixels with ``fill``.

    Args:
        img: 4D image batch (the code indexes ``shape[0]``, ``shape[2]`` and ``shape[3]``).
        grid: Sampling grid passed to ``grid_sample``; it is expanded to the batch size when the batch holds
            more than one image.
        mode: Interpolation mode forwarded to ``grid_sample``; ``"nearest"`` selects the hard-mask fill path,
            any other value the blended one.
        fill: ``None`` to leave out-of-range samples at the zero padding value, otherwise a number or a
            tuple/list of numbers used as fill value.

    Returns:
        The sampled image in ``img``'s dtype. If ``img``'s dtype differed from ``grid``'s, the result is
        rounded before being cast back.
    """
    # We are using context knowledge that grid should have float dtype
    fp = img.dtype == grid.dtype
    float_img = img if fp else img.to(grid.dtype)

    shape = float_img.shape
    if shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(shape[0], -1, -1, -1)

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    # The extra channel is all ones; after sampling with zero padding it is below 1 wherever the sample
    # (partly) fell outside of the input.
    if fill is not None:
        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
        float_img = torch.cat((float_img, mask), dim=1)

    float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        # Split the sampled mask channel (the last one) off again.
        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
        mask = mask.expand_as(float_img)
        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
        fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
        if mode == "nearest":
            # Hard decision: pixels whose mask value is below 0.5 are replaced by the fill value.
            bool_mask = mask < 0.5
            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
        else:  # 'bilinear'
            # The following is mathematically equivalent to:
            # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)

    img = float_img.round_().to(img.dtype) if not fp else float_img

    return img
591
592
593
594
595
596


def _assert_grid_transform_inputs(
    image: torch.Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: _FillTypeJIT,
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    """Validate the arguments of a grid-based transform; returns nothing and raises on invalid input.

    Args:
        image: Image tensor; only ``image.shape[-3]`` (number of channels) is read, and only when ``fill`` is
            a tuple or list.
        matrix: Optional affine coefficients; must be a list of 6 values when given.
        interpolation: Name of the interpolation mode.
        fill: ``None``, an ``int``/``float``, or a tuple/list with either at most one element or one element
            per channel.
        supported_interpolation_modes: Interpolation names accepted by the caller.
        coeffs: Optional coefficients; must have 8 values when given.

    Raises:
        TypeError: If ``matrix`` is given but is not a list.
        ValueError: For a wrong ``matrix``/``coeffs`` length, a ``fill`` of unsupported type or length, or an
            unsupported ``interpolation``.
    """
    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        elif len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None:
        if isinstance(fill, (tuple, list)):
            length = len(fill)
            num_channels = image.shape[-3]
            # Sequences of length 0 or 1 are not checked against the channel count.
            if length > 1 and length != num_channels:
                raise ValueError(
                    "The number of elements in 'fill' cannot broadcast to match the number of "
                    f"channels of the image ({length} != {num_channels})"
                )
        elif not isinstance(fill, (int, float)):
            raise ValueError("Argument fill should be either int, float, tuple or list")

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _affine_grid(
    theta: torch.Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> torch.Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
    dtype = theta.dtype
    device = theta.device

    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


653
@_register_kernel_internal(affine, torch.Tensor)
@_register_kernel_internal(affine, tv_tensors.Image)
def affine_image(
    image: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transform to a tensor image, keeping the image size.

    Args:
        image: Tensor whose last two dimensions are read as ``(height, width)``. Inputs with 3 dimensions, or
            with more than 4, are temporarily reshaped to 4 dimensions.
        angle, translate, scale, shear, center: Affine parameters, validated by ``_affine_parse_args``.
            ``center`` is given in pixel coordinates; ``None`` means the image center.
        interpolation: ``InterpolationMode`` or integer constant; only ``"nearest"`` and ``"bilinear"`` pass
            the input validation.
        fill: Fill value for areas outside of the transformed image, see ``_apply_grid_transform``.

    Returns:
        The transformed image with the shape of the input. An empty input is returned unchanged.
    """
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    shape = image.shape
    ndim = image.ndim

    # The grid helpers operate on 4D batches: fold extra leading dims, or add a batch dim.
    if ndim > 4:
        image = image.reshape((-1,) + shape[-3:])
        needs_unsquash = True
    elif ndim == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True
    else:
        needs_unsquash = False

    height, width = shape[-2:]

    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [float(t) for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    # The grid is built in the image's dtype when that is floating point, otherwise in float32.
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
    # Input and output size are identical, i.e. the canvas is not expanded.
    grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        output = output.reshape(shape)

    return output
704
705


706
@_register_kernel_internal(affine, PIL.Image.Image)
def _affine_image_pil(
    image: PIL.Image.Image,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    """``affine`` kernel for PIL images.

    The arguments are validated with ``_affine_parse_args``, converted to an inverse affine matrix and
    handed to ``_FP.affine``. When ``center`` is ``None`` the image center is used as pivot.
    """
    interpolation = _check_interpolation(interpolation)
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    if center is None:
        # The center is deliberately estimated without a 0.5 offset, i.e. not as
        # (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5): it is visually better, otherwise an image
        # rotated by 90 degrees is shifted vs the output of torch.rot90 or F_t.affine.
        height, width = _get_size_image_pil(image)
        pivot = [width * 0.5, height * 0.5]
    else:
        pivot = center

    inverse_matrix = _get_inverse_affine_matrix(pivot, angle, translate, scale, shear)
    resample = pil_modes_mapping[interpolation]
    return _FP.affine(image, inverse_matrix, interpolation=resample, fill=fill)
729
730


731
732
def _affine_bounding_boxes_with_expand(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    expand: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Apply an affine transformation to bounding boxes, optionally expanding the canvas.

    Every box is converted to XYXY, its four corners are pushed through the forward affine matrix, and the
    axis-aligned hull of the transformed corners becomes the new box. The result is clamped to the (possibly
    expanded) canvas and converted back to ``format`` with the input's shape and dtype.

    Args:
        bounding_boxes: Boxes in ``format``. An empty tensor is returned unchanged together with ``canvas_size``.
        format: Format of ``bounding_boxes``.
        canvas_size: ``(height, width)`` of the canvas the boxes live on.
        angle: Rotation angle.
        translate: Horizontal and vertical translation.
        scale: Overall scale factor.
        shear: Shear angle(s).
        center: Optional pivot point. Defaults to the canvas center.
        expand: If ``True``, the boxes are shifted into the frame of the expanded canvas and the new canvas size
            is returned.

    Returns:
        Tuple of the transformed boxes and the resulting ``(height, width)`` canvas size.
    """
    if bounding_boxes.numel() == 0:
        return bounding_boxes, canvas_size

    input_shape = bounding_boxes.shape
    input_dtype = bounding_boxes.dtype

    # All math happens in floating point on a private copy, so the in-place format conversion below never touches
    # the caller's tensor (`.float()` on a non-float tensor already yields a new tensor).
    if bounding_boxes.is_floating_point():
        boxes = bounding_boxes.clone()
    else:
        boxes = bounding_boxes.float()
    work_dtype = boxes.dtype
    device = boxes.device

    boxes = convert_bounding_box_format(
        boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
    )
    boxes = boxes.reshape(-1, 4)

    angle, translate, shear, center = _affine_parse_args(
        angle, translate, scale, shear, InterpolationMode.NEAREST, center
    )

    height, width = canvas_size
    if center is None:
        center = [width * 0.5, height * 0.5]

    # Forward (non-inverted) matrix, stored transposed with shape (3, 2) so that points can be right-multiplied.
    forward_matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
    matrix_t = torch.tensor(forward_matrix, dtype=work_dtype, device=device).reshape(2, 3).T

    # 1) Expand every box into its four corners in homogeneous coordinates. The resulting tensor has shape
    # (N * 4, 3), where N is the number of boxes, and the corners of a single box are ordered as
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)],
    # i.e. top-left, top-right, bottom-right, bottom-left.
    corners = boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    ones = torch.ones(corners.shape[0], 1, device=device, dtype=work_dtype)
    homogeneous_corners = torch.cat([corners, ones], dim=-1)

    # 2) Transform the corners with the affine matrix.
    moved_corners = torch.matmul(homogeneous_corners, matrix_t)

    # 3) Regroup as [N boxes, 4 corners, x/y] and take the axis-aligned hull of each group.
    moved_corners = moved_corners.reshape(-1, 4, 2)
    hull_mins, hull_maxs = torch.aminmax(moved_corners, dim=1)
    result = torch.cat([hull_mins, hull_maxs], dim=1)

    if expand:
        # Transform the corners of the canvas itself (top-left, bottom-left, bottom-right, top-right) to find
        # the minimum point of the transformed image frame.
        frame = torch.tensor(
            [
                [0.0, 0.0, 1.0],
                [0.0, float(height), 1.0],
                [float(width), float(height), 1.0],
                [float(width), 0.0, 1.0],
            ],
            dtype=work_dtype,
            device=device,
        )
        moved_frame = torch.matmul(frame, matrix_t)
        frame_origin = torch.amin(moved_frame, dim=0, keepdim=True)
        # Translate the boxes so they are expressed relative to the new frame origin.
        result.sub_(frame_origin.repeat((1, 2)))

        # The new canvas size is estimated from the inverted matrix (inverted=True is the default).
        inverse_matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        new_width, new_height = _compute_affine_output_size(inverse_matrix, width, height)
        canvas_size = (new_height, new_width)

    result = clamp_bounding_boxes(result, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
    result = convert_bounding_box_format(
        result, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(input_shape)

    return result.to(input_dtype), canvas_size
818
819


820
821
def affine_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to bounding boxes on a fixed-size canvas.

    Delegates to ``_affine_bounding_boxes_with_expand`` with ``expand=False`` and discards the returned canvas
    size.

    Args:
        bounding_boxes: Boxes in ``format``.
        format: Format of ``bounding_boxes``.
        canvas_size: ``(height, width)`` of the canvas the boxes live on.
        angle: Rotation angle.
        translate: Horizontal and vertical translation.
        scale: Overall scale factor.
        shear: Shear angle(s).
        center: Optional pivot point.

    Returns:
        The transformed boxes.
    """
    boxes_and_canvas = _affine_bounding_boxes_with_expand(
        bounding_boxes,
        format=format,
        canvas_size=canvas_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
        expand=False,
    )
    return boxes_and_canvas[0]
842
843


844
@_register_kernel_internal(affine, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _affine_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.BoundingBoxes:
    """``affine`` kernel registered for ``tv_tensors.BoundingBoxes``.

    Reads ``format`` and ``canvas_size`` from ``inpt``, runs ``affine_bounding_boxes`` on the plain tensor view and
    wraps the result like ``inpt``. Extra keyword arguments are accepted but not forwarded.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    transformed = affine_bounding_boxes(
        plain_boxes,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
    )
    return tv_tensors.wrap(transformed, like=inpt)
865
866


867
868
def affine_mask(
    mask: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to a mask.

    The mask is transformed by ``affine_image`` with nearest interpolation. A mask with fewer than three
    dimensions temporarily gets a leading dimension, which is removed again from the output.

    Args:
        mask: Mask tensor.
        angle: Rotation angle.
        translate: Horizontal and vertical translation.
        scale: Overall scale factor.
        shear: Shear angle(s).
        fill: Fill value forwarded to ``affine_image``.
        center: Optional pivot point.

    Returns:
        The transformed mask.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    transformed = affine_image(
        mask,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    return transformed.squeeze(0) if needs_squeeze else transformed

898

899
@_register_kernel_internal(affine, tv_tensors.Mask, tv_tensor_wrapper=False)
def _affine_mask_dispatch(
    inpt: tv_tensors.Mask,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.Mask:
    """``affine`` kernel registered for ``tv_tensors.Mask``.

    Runs ``affine_mask`` on the plain tensor view of ``inpt`` and wraps the result like ``inpt``. Extra keyword
    arguments (``**kwargs``) are accepted but not forwarded.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    transformed = affine_mask(
        plain_mask,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        fill=fill,
        center=center,
    )
    return tv_tensors.wrap(transformed, like=inpt)
920
921


922
@_register_kernel_internal(affine, tv_tensors.Video)
def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to a video.

    The video tensor is passed unchanged to ``affine_image`` together with all other arguments.
    """
    transformed = affine_image(
        video,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )
    return transformed


945
def rotate(
    inpt: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomRotation` for details.

    Dispatcher: looks up the ``rotate`` kernel registered for ``type(inpt)`` and calls it with the given arguments.
    """
    if torch.jit.is_scripting():
        # Under TorchScript there is no kernel lookup; the tensor image kernel is called directly.
        return rotate_image(inpt, angle=angle, interpolation=interpolation, expand=expand, center=center, fill=fill)

    _log_api_usage_once(rotate)

    return _get_kernel(rotate, type(inpt))(
        inpt, angle=angle, interpolation=interpolation, expand=expand, center=center, fill=fill
    )


@_register_kernel_internal(rotate, torch.Tensor)
@_register_kernel_internal(rotate, tv_tensors.Image)
def rotate_image(
    image: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate a tensor image by resampling it on a rotated grid.

    Args:
        image: Tensor whose last three dimensions are (channels, height, width); leading dimensions are
            flattened for the computation and restored afterwards.
        angle: Rotation angle.
        interpolation: Interpolation mode, or the corresponding Pillow integer constant. Only ``"nearest"`` and
            ``"bilinear"`` are passed as supported modes to ``_assert_grid_transform_inputs``.
        expand: If ``True``, the output size is computed by ``_compute_affine_output_size`` instead of keeping
            the input size.
        center: Optional pivot point in pixel coordinates. Defaults to the image center.
        fill: Fill value forwarded to ``_apply_grid_transform``.

    Returns:
        The rotated image with the original leading dimensions.
    """
    interpolation = _check_interpolation(interpolation)

    input_shape = image.shape
    num_channels, height, width = input_shape[-3:]

    if center is None:
        pivot = [0.0, 0.0]
    else:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        pivot = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    # due to current incoherence of rotation angle direction between affine and rotate implementations
    # we need to set -angle.
    matrix = _get_inverse_affine_matrix(pivot, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

    if image.numel() == 0:
        # Nothing to resample for an empty input; only the output size has to be determined.
        output = image
        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
    else:
        batched = image.reshape(-1, num_channels, height, width)

        _assert_grid_transform_inputs(batched, matrix, interpolation.value, fill, ["nearest", "bilinear"])

        if expand:
            ow, oh = _compute_affine_output_size(matrix, width, height)
        else:
            ow, oh = width, height

        # The sampling grid is built in the image's dtype if it is floating point, otherwise in float32.
        grid_dtype = batched.dtype if torch.is_floating_point(batched) else torch.float32
        theta = torch.tensor(matrix, dtype=grid_dtype, device=batched.device).reshape(1, 2, 3)
        grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
        output = _apply_grid_transform(batched, grid, interpolation.value, fill=fill)

        new_height, new_width = output.shape[-2:]

    return output.reshape(input_shape[:-3] + (num_channels, new_height, new_width))
1004
1005


1006
@_register_kernel_internal(rotate, PIL.Image.Image)
def _rotate_image_pil(
    image: PIL.Image.Image,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    """Rotate a PIL image.

    ``interpolation`` is validated, mapped to its Pillow constant through ``pil_modes_mapping`` and everything is
    forwarded to the PIL backend (``_FP.rotate``).
    """
    interpolation = _check_interpolation(interpolation)
    pil_interpolation = pil_modes_mapping[interpolation]

    return _FP.rotate(image, angle, interpolation=pil_interpolation, expand=expand, fill=fill, center=center)


1022
1023
def rotate_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Rotate bounding boxes.

    Implemented as a pure-rotation affine transformation (no translation, unit scale, no shear) through
    ``_affine_bounding_boxes_with_expand``. The angle is negated before it is passed on, as ``rotate_image`` does
    when it builds its matrix.

    Args:
        bounding_boxes: Boxes in ``format``.
        format: Format of ``bounding_boxes``.
        canvas_size: ``(height, width)`` of the canvas the boxes live on.
        angle: Rotation angle.
        expand: Whether the canvas is expanded to hold the whole rotated frame.
        center: Optional pivot point.

    Returns:
        Tuple of the rotated boxes and the resulting ``(height, width)`` canvas size.
    """
    no_translation = [0.0, 0.0]
    no_shear = [0.0, 0.0]
    return _affine_bounding_boxes_with_expand(
        bounding_boxes,
        format=format,
        canvas_size=canvas_size,
        angle=-angle,
        translate=no_translation,
        scale=1.0,
        shear=no_shear,
        center=center,
        expand=expand,
    )
1041
1042


1043
@_register_kernel_internal(rotate, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _rotate_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
) -> tv_tensors.BoundingBoxes:
    """``rotate`` kernel registered for ``tv_tensors.BoundingBoxes``.

    Runs ``rotate_bounding_boxes`` on the plain tensor view of ``inpt`` and wraps the result like ``inpt``, with
    the canvas size returned by the kernel. Extra keyword arguments are accepted but not forwarded.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    rotated, new_canvas_size = rotate_bounding_boxes(
        plain_boxes,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        angle=angle,
        expand=expand,
        center=center,
    )
    return tv_tensors.wrap(rotated, like=inpt, canvas_size=new_canvas_size)
1056
1057


1058
1059
def rotate_mask(
    mask: torch.Tensor,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate a mask.

    The mask is rotated by ``rotate_image`` with nearest interpolation. A mask with fewer than three dimensions
    temporarily gets a leading dimension, which is removed again from the output.

    Args:
        mask: Mask tensor.
        angle: Rotation angle.
        expand: Whether the output is expanded to hold the whole rotated mask.
        center: Optional pivot point.
        fill: Fill value forwarded to ``rotate_image``.

    Returns:
        The rotated mask.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    rotated = rotate_image(
        mask,
        angle=angle,
        expand=expand,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    return rotated.squeeze(0) if needs_squeeze else rotated

1085

1086
@_register_kernel_internal(rotate, tv_tensors.Mask, tv_tensor_wrapper=False)
def _rotate_mask_dispatch(
    inpt: tv_tensors.Mask,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
    **kwargs,
) -> tv_tensors.Mask:
    """``rotate`` kernel registered for ``tv_tensors.Mask``.

    Runs ``rotate_mask`` on the plain tensor view of ``inpt`` and wraps the result like ``inpt``. Extra keyword
    arguments (``**kwargs``) are accepted but not forwarded.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    rotated = rotate_mask(plain_mask, angle=angle, expand=expand, fill=fill, center=center)
    return tv_tensors.wrap(rotated, like=inpt)
1097
1098


1099
@_register_kernel_internal(rotate, tv_tensors.Video)
def rotate_video(
    video: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Rotate a video.

    The video tensor is passed unchanged to ``rotate_image`` together with all other arguments.
    """
    rotated = rotate_image(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
    return rotated
1109
1110


1111
def pad(
    inpt: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.Pad` for details.

    Dispatcher: looks up the ``pad`` kernel registered for ``type(inpt)`` and calls it with the given arguments.
    """
    if torch.jit.is_scripting():
        # Under TorchScript there is no kernel lookup; the tensor image kernel is called directly.
        return pad_image(inpt, padding=padding, fill=fill, padding_mode=padding_mode)

    _log_api_usage_once(pad)

    return _get_kernel(pad, type(inpt))(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
1125
1126


1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, (tuple, list)):
        if len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        elif len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        else:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )
    else:
        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")

    return [pad_left, pad_right, pad_top, pad_bottom]
1149

1150

1151
@_register_kernel_internal(pad, torch.Tensor)
@_register_kernel_internal(pad, tv_tensors.Image)
def pad_image(
    image: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad a tensor image.

    Args:
        image: Image tensor to pad.
        padding: Padding specification, see ``_parse_pad_padding`` for the accepted forms.
        fill: Fill value. ``None`` is treated as ``0``. A scalar or a one-element sequence is handled by
            ``_pad_with_scalar_fill``, any other sequence by ``_pad_with_vector_fill``.
        padding_mode: One of ``"constant"``, ``"edge"``, ``"reflect"`` or ``"symmetric"``.

    Returns:
        The padded image.

    Raises:
        ValueError: If ``padding_mode`` is not one of the supported modes.
    """
    # `padding` follows the PIL-aligned public API and is ordered `[left, top, right, bottom]`. `torch_pad`, which is
    # used internally, expects `[left, right, top, bottom]` instead, so the values are reordered here.
    torch_padding = _parse_pad_padding(padding)

    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
        raise ValueError(
            f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
            f"but got `'{padding_mode}'`."
        )

    if fill is None:
        fill = 0

    if isinstance(fill, (int, float)):
        padded = _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
    elif len(fill) == 1:
        # A one-element sequence is equivalent to a scalar fill.
        padded = _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
    else:
        padded = _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
    return padded
1179
1180
1181


def _pad_with_scalar_fill(
1182
    image: torch.Tensor,
1183
1184
1185
    torch_padding: List[int],
    fill: Union[int, float],
    padding_mode: str,
1186
) -> torch.Tensor:
1187
1188
    shape = image.shape
    num_channels, height, width = shape[-3:]
1189

1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
    batch_size = 1
    for s in shape[:-3]:
        batch_size *= s

    image = image.reshape(batch_size, num_channels, height, width)

    if padding_mode == "edge":
        # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
        # name.
        padding_mode = "replicate"

    if padding_mode == "constant":
        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
    elif padding_mode in ("reflect", "replicate"):
        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
        # TODO: See https://github.com/pytorch/pytorch/issues/40763
        dtype = image.dtype
        if not image.is_floating_point():
            needs_cast = True
            image = image.to(torch.float32)
        else:
            needs_cast = False
1213

1214
1215
1216
1217
1218
        image = torch_pad(image, torch_padding, mode=padding_mode)

        if needs_cast:
            image = image.to(dtype)
    else:  # padding_mode == "symmetric"
1219
        image = _pad_symmetric(image, torch_padding)
1220
1221

    new_height, new_width = image.shape[-2:]
1222

1223
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
1224
1225


1226
# TODO: This should be removed once torch_pad supports non-scalar padding values
def _pad_with_vector_fill(
    image: torch.Tensor,
    torch_padding: List[int],
    fill: List[float],
    padding_mode: str,
) -> torch.Tensor:
    """Pad an image tensor with a per-channel fill value.

    The image is first padded with zeros through ``_pad_with_scalar_fill`` and the padded borders are then
    overwritten with ``fill``.

    Args:
        image: Image tensor to pad.
        torch_padding: Padding in ``torch_pad`` order, i.e. ``[left, right, top, bottom]``.
        fill: Fill values; reshaped to ``(-1, 1, 1)`` so they broadcast over height and width.
        padding_mode: Must be ``"constant"``.

    Returns:
        The padded tensor.

    Raises:
        ValueError: If ``padding_mode`` is not ``"constant"``.
    """
    if padding_mode != "constant":
        raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

    output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    left, right, top, bottom = torch_padding

    # The tensor is created in the autodetected dtype first and only afterwards converted to the image dtype. This
    # avoids an implicit float -> int conversion, which would otherwise happen for example for the valid input of a
    # uint8 image with a floating point fill value.
    fill_values = torch.tensor(fill, device=image.device)
    fill_values = fill_values.to(dtype=image.dtype).reshape(-1, 1, 1)

    # Only positive paddings added a border; the `> 0` guards also keep `-0:` from selecting everything.
    if top > 0:
        output[..., :top, :] = fill_values
    if left > 0:
        output[..., :, :left] = fill_values
    if bottom > 0:
        output[..., -bottom:, :] = fill_values
    if right > 0:
        output[..., :, -right:] = fill_values
    return output


1255
# Register the PIL backend function `_FP.pad` as the `pad` kernel for PIL images. `_pad_image_pil` is bound to whatever
# the registration decorator returns.
_pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)
1256
1257


1258
@_register_kernel_internal(pad, tv_tensors.Mask)
def pad_mask(
    mask: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad a mask.

    The mask is padded by ``pad_image``. A mask with fewer than three dimensions temporarily gets a leading
    dimension, which is removed again from the output.

    Args:
        mask: Mask tensor.
        padding: Padding specification, forwarded to ``pad_image``.
        fill: Scalar fill value. ``None`` is treated as ``0``.
        padding_mode: Padding mode, forwarded to ``pad_image``.

    Returns:
        The padded mask.

    Raises:
        ValueError: If ``fill`` is a tuple or list.
    """
    if fill is None:
        fill = 0

    if isinstance(fill, (tuple, list)):
        raise ValueError("Non-scalar fill value is not supported")

    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    padded = pad_image(mask, padding=padding, fill=fill, padding_mode=padding_mode)

    if needs_squeeze:
        padded = padded.squeeze(0)
    return padded
1283
1284


1285
1286
def pad_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    padding: List[int],
    padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Shift bounding boxes to account for padding of their canvas.

    Args:
        bounding_boxes: Boxes in ``format``.
        format: Format of ``bounding_boxes``.
        canvas_size: ``(height, width)`` of the canvas before padding.
        padding: Padding specification, see ``_parse_pad_padding`` for the accepted forms.
        padding_mode: Only ``"constant"`` is supported.

    Returns:
        Tuple of the shifted boxes, clamped to the padded canvas, and the padded ``(height, width)`` canvas size.

    Raises:
        ValueError: If ``padding_mode`` is not ``"constant"``.
    """
    if padding_mode != "constant":
        # TODO: add support of other padding modes
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

    left, right, top, bottom = _parse_pad_padding(padding)

    # In XYXY both corner points move by (left, top). In the other formats only the first two entries are shifted.
    if format == tv_tensors.BoundingBoxFormat.XYXY:
        shift = [left, top, left, top]
    else:
        shift = [left, top, 0, 0]
    shifted_boxes = bounding_boxes + torch.tensor(shift, dtype=bounding_boxes.dtype, device=bounding_boxes.device)

    old_height, old_width = canvas_size
    canvas_size = (old_height + top + bottom, old_width + left + right)

    return clamp_bounding_boxes(shifted_boxes, format=format, canvas_size=canvas_size), canvas_size
1310
1311


1312
@_register_kernel_internal(pad, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _pad_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
) -> tv_tensors.BoundingBoxes:
    """``pad`` kernel registered for ``tv_tensors.BoundingBoxes``.

    Runs ``pad_bounding_boxes`` on the plain tensor view of ``inpt`` and wraps the result like ``inpt``, with the
    canvas size returned by the kernel. Extra keyword arguments are accepted but not forwarded.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    padded, new_canvas_size = pad_bounding_boxes(
        plain_boxes,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        padding=padding,
        padding_mode=padding_mode,
    )
    return tv_tensors.wrap(padded, like=inpt, canvas_size=new_canvas_size)
1324
1325


1326
@_register_kernel_internal(pad, tv_tensors.Video)
def pad_video(
    video: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """Pad a video.

    The video tensor is passed unchanged to ``pad_image`` together with all other arguments.
    """
    padded = pad_image(video, padding, fill=fill, padding_mode=padding_mode)
    return padded
1334
1335


1336
def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomCrop` for details.

    Dispatcher: looks up the ``crop`` kernel registered for ``type(inpt)`` and calls it with the given arguments.
    """
    if torch.jit.is_scripting():
        # Under TorchScript there is no kernel lookup; the tensor image kernel is called directly.
        return crop_image(inpt, top=top, left=left, height=height, width=width)

    _log_api_usage_once(crop)

    return _get_kernel(crop, type(inpt))(inpt, top=top, left=left, height=height, width=width)
1345

1346
1347

@_register_kernel_internal(crop, torch.Tensor)
@_register_kernel_internal(crop, tv_tensors.Image)
def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a tensor image to the window ``[top, top + height) x [left, left + width)``.

    Parts of the window that lie outside of the image are filled with zeros, so the output always has
    ``height`` rows and ``width`` columns.

    Args:
        image: Tensor whose last two dimensions are (height, width).
        top: Row of the upper edge of the crop window. May be negative.
        left: Column of the left edge of the crop window. May be negative.
        height: Height of the crop window.
        width: Width of the crop window.

    Returns:
        The cropped (and, if the window exceeds the image, zero-padded) image.
    """
    h, w = image.shape[-2:]

    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        # Keep only the part of the window that overlaps the image. Both slice bounds are clamped at 0: a negative
        # stop (window entirely above or left of the image) would otherwise be interpreted relative to the end of
        # the dimension and select the wrong region, producing an output of the wrong size.
        image = image[..., max(top, 0) : max(bottom, 0), max(left, 0) : max(right, 0)]
        # Amount of the window outside the image on each side, in `torch_pad` order [left, right, top, bottom].
        torch_padding = [
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    return image[..., top:bottom, left:right]


1367
1368
# PIL images are cropped by the PIL backend function `_FP.crop`; it is aliased here and registered as the `crop` kernel
# for PIL images.
_crop_image_pil = _FP.crop
_register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil)
1369
1370


1371
1372
def crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Shift bounding boxes into the coordinate frame of a crop window.

    Args:
        bounding_boxes: Boxes in ``format``.
        format: Format of ``bounding_boxes``.
        top: Row of the upper edge of the crop window.
        left: Column of the left edge of the crop window.
        height: Height of the crop window.
        width: Width of the crop window.

    Returns:
        Tuple of the shifted boxes, clamped to the crop window, and the new ``(height, width)`` canvas size.
    """
    # Subtracting the window origin crops the boxes; negative `left` and/or `top` values implicitly pad instead.
    # In XYXY both corner points move; in the other formats only the first two entries are shifted.
    if format == tv_tensors.BoundingBoxFormat.XYXY:
        origin = [left, top, left, top]
    else:
        origin = [left, top, 0, 0]

    shifted_boxes = bounding_boxes - torch.tensor(origin, dtype=bounding_boxes.dtype, device=bounding_boxes.device)
    new_canvas_size = (height, width)

    return clamp_bounding_boxes(shifted_boxes, format=format, canvas_size=new_canvas_size), new_canvas_size
1390
1391


1392
@_register_kernel_internal(crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int
) -> tv_tensors.BoundingBoxes:
    """``crop`` kernel registered for ``tv_tensors.BoundingBoxes``.

    Runs ``crop_bounding_boxes`` on the plain tensor view of ``inpt`` and wraps the result like ``inpt``, with the
    canvas size returned by the kernel.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    cropped, new_canvas_size = crop_bounding_boxes(
        plain_boxes, format=inpt.format, top=top, left=left, height=height, width=width
    )
    return tv_tensors.wrap(cropped, like=inpt, canvas_size=new_canvas_size)
1400
1401


1402
@_register_kernel_internal(crop, tv_tensors.Mask)
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a mask.

    The mask is cropped by ``crop_image``. A mask with fewer than three dimensions temporarily gets a leading
    dimension, which is removed again from the output.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    cropped = crop_image(mask, top, left, height, width)

    return cropped.squeeze(0) if needs_squeeze else cropped
1416
1417


1418
@_register_kernel_internal(crop, tv_tensors.Video)
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a video.

    The video tensor is passed unchanged to ``crop_image`` together with the crop window.
    """
    cropped = crop_image(video, top, left, height, width)
    return cropped
1421
1422


1423
def perspective(
    inpt: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomPerspective` for details.

    Dispatcher: looks up the ``perspective`` kernel registered for ``type(inpt)`` and calls it with the given
    arguments.
    """
    if torch.jit.is_scripting():
        # Under TorchScript there is no kernel lookup; the tensor image kernel is called directly.
        return perspective_image(
            inpt,
            startpoints=startpoints,
            endpoints=endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )

    _log_api_usage_once(perspective)

    return _get_kernel(perspective, type(inpt))(
        inpt,
        startpoints=startpoints,
        endpoints=endpoints,
        interpolation=interpolation,
        fill=fill,
        coefficients=coefficients,
    )

1454

1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394

    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
1470
    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
1471
    base_grid[..., 0].copy_(x_grid)
1472
    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
1473
1474
1475
1476
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
1477
1478
1479
    shape = (1, oh * ow, 3)
    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))
1480
1481
1482
1483
1484

    output_grid = output_grid1.div_(output_grid2).sub_(1.0)
    return output_grid.view(1, oh, ow, 2)


1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
def _perspective_coefficients(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]],
) -> List[float]:
    if coefficients is not None:
        if startpoints is not None and endpoints is not None:
            raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.")
        elif len(coefficients) != 8:
            raise ValueError("Argument coefficients should have 8 float values")
        return coefficients
    elif startpoints is not None and endpoints is not None:
        return _get_perspective_coeffs(startpoints, endpoints)
    else:
        raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")


1502
@_register_kernel_internal(perspective, torch.Tensor)
@_register_kernel_internal(perspective, tv_tensors.Image)
def perspective_image(
    image: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transformation to a tensor image.

    Args:
        image: Image tensor. An empty tensor is returned unchanged (after the arguments have been validated).
        startpoints: Points in the original image, or ``None`` if ``coefficients`` is given.
        endpoints: Corresponding points in the transformed image, or ``None`` if ``coefficients`` is given.
        interpolation: Interpolation mode, or the corresponding Pillow integer constant. Only ``"nearest"`` and
            ``"bilinear"`` are passed as supported modes to ``_assert_grid_transform_inputs``.
        fill: Fill value forwarded to ``_apply_grid_transform``.
        coefficients: The eight perspective coefficients, as an alternative to the point lists.

    Returns:
        The transformed image with the same shape as the input.
    """
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    input_shape = image.shape
    ndim = image.ndim

    # The grid transform below is run on a 4D tensor: extra leading dimensions are flattened into one, and a 3D
    # input gets a temporary leading dimension. In both cases the original shape is restored at the end.
    needs_unsquash = ndim > 4 or ndim == 3
    if ndim > 4:
        image = image.reshape((-1,) + input_shape[-3:])
    elif ndim == 3:
        image = image.unsqueeze(0)

    _assert_grid_transform_inputs(
        image,
        matrix=None,
        interpolation=interpolation.value,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    oh, ow = input_shape[-2:]
    # The sampling grid is built in the image's dtype if it is floating point, otherwise in float32.
    grid_dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=grid_dtype, device=image.device)
    transformed = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        transformed = transformed.reshape(input_shape)

    return transformed
1548
1549


1550
@_register_kernel_internal(perspective, PIL.Image.Image)
def _perspective_image_pil(
    image: PIL.Image.Image,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
    """Perspective-warp a PIL image by delegating to ``_FP.perspective``.

    NOTE(review): the default interpolation here is BICUBIC while the tensor and video
    kernels in this file default to BILINEAR - confirm whether that is intentional.
    """
    coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    mode = _check_interpolation(interpolation)
    # `_FP.perspective` expects the Pillow resampling constant, not the enum member.
    pil_mode = pil_modes_mapping[mode]
    return _FP.perspective(image, coeffs, interpolation=pil_mode, fill=fill)
1562
1563


1564
1565
def perspective_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a perspective transform to bounding boxes.

    The four corners of every box are mapped through the inverse of the perspective
    coefficients; the output box is the axis-aligned hull of the mapped corners, clamped
    to ``canvas_size`` and converted back to ``format``. Empty inputs are returned as is.

    Raises:
        ValueError: propagated from ``_perspective_coefficients`` for invalid argument
            combinations.
        RuntimeError: if the coefficients cannot be inverted (zero denominator).
    """
    if bounding_boxes.numel() == 0:
        return bounding_boxes

    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
    bounding_boxes = (
        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    # All intermediate math runs in `dtype`: the boxes' own dtype if floating, else float32.
    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32
    device = bounding_boxes.device

    # perspective_coeffs are computed as endpoint -> start point
    # We have to invert perspective_coeffs for bboxes:
    # (x, y) - end point and (x_out, y_out) - start point
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # and we would like to get:
    # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # and compute inv_coeffs in terms of coeffs

    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
    if denom == 0:
        raise RuntimeError(
            f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
            f"Denominator is zero, denom={denom}"
        )

    inv_coeffs = [
        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
        (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
        (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
        (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
    ]

    # theta1 holds the numerator rows, theta2 the (shared) denominator row repeated for x and y.
    theta1 = torch.tensor(
        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
        dtype=dtype,
        device=device,
    )

    theta2 = torch.tensor(
        [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
    )

    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-right, bottom-left corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2).to(dtype)
    # Build the homogeneous column in the same dtype/device as theta1/theta2. A default float32
    # column would promote half-precision points to float32 and break the matmuls below.
    ones = torch.ones(points.shape[0], 1, dtype=dtype, device=device)
    points = torch.cat([points, ones], dim=-1)
    # 2) Now let's transform the points using perspective matrices
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)

    numer_points = torch.matmul(points, theta1.T)
    denom_points = torch.matmul(points, theta2.T)
    transformed_points = numer_points.div_(denom_points)

    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)

    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    # out_bboxes should be of shape [N boxes, 4]

    return convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
1656
1657


1658
@_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _perspective_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.BoundingBoxes:
    """Unwrap a ``BoundingBoxes`` input, run ``perspective_bounding_boxes`` and re-wrap the result.

    Extra keyword arguments are accepted and ignored.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    transformed = perspective_bounding_boxes(
        plain_boxes,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        startpoints=startpoints,
        endpoints=endpoints,
        coefficients=coefficients,
    )
    return tv_tensors.wrap(transformed, like=inpt)
1675
1676


1677
1678
def perspective_mask(
    mask: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Perspective-warp a mask via ``perspective_image`` with nearest-neighbour interpolation.

    Masks with fewer than 3 dimensions temporarily get a leading dimension, which is
    removed again before returning.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    warped = perspective_image(
        mask, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients
    )

    if needs_squeeze:
        warped = warped.squeeze(0)
    return warped

1699

1700
@_register_kernel_internal(perspective, tv_tensors.Mask, tv_tensor_wrapper=False)
def _perspective_mask_dispatch(
    inpt: tv_tensors.Mask,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.Mask:
    """Unwrap a ``Mask`` input, run ``perspective_mask`` and re-wrap the result.

    Extra keyword arguments are accepted and ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    warped = perspective_mask(
        plain_mask,
        startpoints=startpoints,
        endpoints=endpoints,
        fill=fill,
        coefficients=coefficients,
    )
    return tv_tensors.wrap(warped, like=inpt)
1717
1718


1719
@_register_kernel_internal(perspective, tv_tensors.Video)
def perspective_video(
    video: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """Perspective-warp a video by forwarding it unchanged to ``perspective_image``."""
    return perspective_image(
        video,
        startpoints,
        endpoints,
        interpolation=interpolation,
        fill=fill,
        coefficients=coefficients,
    )
1731
1732


1733
def elastic(
    inpt: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.ElasticTransform` for details."""
    # Under TorchScript there is no type-based dispatch: go straight to the tensor kernel.
    if torch.jit.is_scripting():
        return elastic_image(inpt, displacement=displacement, interpolation=interpolation, fill=fill)

    _log_api_usage_once(elastic)

    # Eager mode: pick the kernel registered for the concrete input type.
    return _get_kernel(elastic, type(inpt))(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
1747
1748


1749
1750
1751
# `elastic_transform` is a second module-level name bound to the same function object as `elastic`.
elastic_transform = elastic


1752
@_register_kernel_internal(elastic, torch.Tensor)
@_register_kernel_internal(elastic, tv_tensors.Image)
def elastic_image(
    image: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Elastically deform a tensor image by sampling it on ``identity grid + displacement``.

    ``displacement`` must have shape ``(1, H, W, 2)`` where ``H, W`` are the last two image
    dimensions. Empty images are returned untouched. Inputs with more than 4 dimensions are
    flattened to a 4-D batch, 3-D inputs get a leading batch dimension, and the original
    shape is restored on output.

    Raises:
        TypeError: if ``displacement`` is not a tensor.
        ValueError: if ``displacement`` has an unexpected shape.
    """
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    interpolation = _check_interpolation(interpolation)

    if image.numel() == 0:
        return image

    in_shape = image.shape
    rank = image.ndim

    device = image.device
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32

    # Patch: elastic transform should support (cpu,f16) input.
    # Such inputs are processed in float32 and cast back to float16 at the very end.
    is_cpu_half = device.type == "cpu" and dtype == torch.float16
    if is_cpu_half:
        image = image.to(torch.float32)
        dtype = torch.float32

    # We are aware that if input image dtype is uint8 and displacement is float64 then
    # displacement will be casted to float32 and all computations will be done with float32
    # We can fix this later if needed

    expected_shape = (1,) + in_shape[-2:] + (2,)
    if expected_shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    # Bring the input to a 4-D layout, remembering whether we have to undo that afterwards.
    needs_unsquash = False
    if rank > 4:
        image = image.reshape((-1,) + in_shape[-3:])
        needs_unsquash = True
    elif rank == 3:
        image = image.unsqueeze(0)
        needs_unsquash = True

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    image_height, image_width = in_shape[-2:]
    identity = _create_identity_grid((image_height, image_width), device=device, dtype=dtype)
    grid = identity.add_(displacement)
    deformed = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if needs_unsquash:
        deformed = deformed.reshape(in_shape)

    if is_cpu_half:
        deformed = deformed.to(torch.float16)

    return deformed
1811
1812


1813
@_register_kernel_internal(elastic, PIL.Image.Image)
def _elastic_image_pil(
    image: PIL.Image.Image,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    """Elastically deform a PIL image by round-tripping through the tensor kernel.

    The image is converted with ``pil_to_tensor``, deformed by ``elastic_image`` and converted
    back with ``to_pil_image`` using the input's mode.
    """
    as_tensor = pil_to_tensor(image)
    deformed = elastic_image(as_tensor, displacement, interpolation=interpolation, fill=fill)
    return to_pil_image(deformed, mode=image.mode)
1823
1824


1825
def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
1826
    sy, sx = size
1827
1828
    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
1829
1830
    base_grid[..., 0].copy_(x_grid)

1831
    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
1832
1833
1834
1835
1836
    base_grid[..., 1].copy_(y_grid)

    return base_grid


1837
1838
def elastic_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    displacement: torch.Tensor,
) -> torch.Tensor:
    """Apply an elastic deformation to bounding boxes.

    The inverse sampling grid is approximated as ``identity_grid - displacement`` (not an exact
    inverse). The four corners of every box are looked up in that grid, mapped back to pixel
    coordinates, and the output box is the axis-aligned hull of the mapped corners, clamped to
    ``canvas_size`` and converted back to ``format``. Empty inputs are returned as is.

    Raises:
        TypeError: if ``displacement`` is not a tensor.
        ValueError: if ``displacement`` is not of shape ``(1, canvas_size[0], canvas_size[1], 2)``.
    """
    expected_shape = (1, canvas_size[0], canvas_size[1], 2)
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")
    elif displacement.shape != expected_shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    if bounding_boxes.numel() == 0:
        return bounding_boxes

    # TODO: add in docstring about approximation we are doing for grid inversion
    device = bounding_boxes.device
    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
    bounding_boxes = (
        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
    # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
    # This is not an exact inverse of the grid
    inv_grid = id_grid.sub_(displacement)

    # Get points from bboxes
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    if points.is_floating_point():
        points = points.ceil_()
    index_xy = points.to(dtype=torch.long)
    # Keep the lookup inside the grid: a corner lying on the right/bottom border (x == width or
    # y == height) would otherwise index out of range, and a negative coordinate would silently
    # wrap around to the opposite side.
    index_x = index_xy[:, 0].clamp(0, inv_grid.shape[2] - 1)
    index_y = index_xy[:, 1].clamp(0, inv_grid.shape[1] - 1)

    # Transform points:
    # grid values are rescaled to pixel coordinates via (v + 1) * 0.5 * size - 0.5
    t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    return convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
1892
1893


1894
@_register_kernel_internal(elastic, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _elastic_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, displacement: torch.Tensor, **kwargs
) -> tv_tensors.BoundingBoxes:
    """Unwrap a ``BoundingBoxes`` input, run ``elastic_bounding_boxes`` and re-wrap the result.

    Extra keyword arguments are accepted and ignored.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    deformed = elastic_bounding_boxes(
        plain_boxes,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        displacement=displacement,
    )
    return tv_tensors.wrap(deformed, like=inpt)
1902
1903


1904
1905
1906
def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Elastically deform a mask via ``elastic_image`` with nearest-neighbour interpolation.

    Masks with fewer than 3 dimensions temporarily get a leading dimension, which is
    removed again before returning.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    deformed = elastic_image(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)

    if needs_squeeze:
        deformed = deformed.squeeze(0)
    return deformed
1921
1922


1923
@_register_kernel_internal(elastic, tv_tensors.Mask, tv_tensor_wrapper=False)
def _elastic_mask_dispatch(
    inpt: tv_tensors.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
) -> tv_tensors.Mask:
    """Unwrap a ``Mask`` input, run ``elastic_mask`` and re-wrap the result.

    Extra keyword arguments are accepted and ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    deformed = elastic_mask(plain_mask, displacement=displacement, fill=fill)
    return tv_tensors.wrap(deformed, like=inpt)
1929
1930


1931
@_register_kernel_internal(elastic, tv_tensors.Video)
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """Elastically deform a video by forwarding it unchanged to ``elastic_image``."""
    return elastic_image(
        video,
        displacement,
        interpolation=interpolation,
        fill=fill,
    )
1939
1940


1941
def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.CenterCrop` for details."""
    # Under TorchScript there is no type-based dispatch: go straight to the tensor kernel.
    if torch.jit.is_scripting():
        return center_crop_image(inpt, output_size=output_size)

    _log_api_usage_once(center_crop)

    # Eager mode: pick the kernel registered for the concrete input type.
    kernel = _get_kernel(center_crop, type(inpt))
    return kernel(inpt, output_size=output_size)
1950
1951


1952
1953
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    if isinstance(output_size, numbers.Number):
1954
1955
        s = int(output_size)
        return [s, s]
1956
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
1957
        return [output_size[0], output_size[0]]
1958
1959
    else:
        return list(output_size)
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978


def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
    return [
        (crop_width - image_width) // 2 if crop_width > image_width else 0,
        (crop_height - image_height) // 2 if crop_height > image_height else 0,
        (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
        (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
    ]


def _center_crop_compute_crop_anchor(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop_top, crop_left


1979
@_register_kernel_internal(center_crop, torch.Tensor)
@_register_kernel_internal(center_crop, tv_tensors.Image)
def center_crop_image(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Crop the centre ``output_size`` region out of a tensor image.

    If the requested crop is larger than the image along any axis, the image is first padded
    with zeros. Empty images are only reshaped so that their last two dimensions match the
    requested size.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)

    in_shape = image.shape
    if image.numel() == 0:
        return image.reshape(in_shape[:-2] + (crop_height, crop_width))
    image_height, image_width = in_shape[-2:]

    if crop_height > image_height or crop_width > image_width:
        pad_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = torch_pad(image, _parse_pad_padding(pad_ltrb), value=0.0)

        image_height, image_width = image.shape[-2:]
        # Padding alone may already yield exactly the requested size.
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    crop_bottom = crop_top + crop_height
    crop_right = crop_left + crop_width
    return image[..., crop_top:crop_bottom, crop_left:crop_right]
1998
1999


2000
@_register_kernel_internal(center_crop, PIL.Image.Image)
def _center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
    """Crop the centre ``output_size`` region out of a PIL image.

    If the requested crop is larger than the image along any axis, the image is first padded
    with ``fill=0``.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    image_height, image_width = _get_size_image_pil(image)

    if crop_height > image_height or crop_width > image_width:
        pad_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = _pad_image_pil(image, pad_ltrb, fill=0)

        image_height, image_width = _get_size_image_pil(image)
        # Padding alone may already yield exactly the requested size.
        if crop_width == image_width and crop_height == image_height:
            return image

    anchor_top, anchor_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return _crop_image_pil(image, anchor_top, anchor_left, crop_height, crop_width)
2015
2016


2017
2018
def center_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Centre-crop bounding boxes by delegating to ``crop_bounding_boxes``.

    The crop anchor is computed from ``canvas_size`` and the parsed ``output_size``; the
    return value is whatever ``crop_bounding_boxes`` returns for that window.
    """
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    anchor_top, anchor_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
    return crop_bounding_boxes(
        bounding_boxes,
        format,
        top=anchor_top,
        left=anchor_left,
        height=crop_height,
        width=crop_width,
    )
2028
2029


2030
@_register_kernel_internal(center_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _center_crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, output_size: List[int]
) -> tv_tensors.BoundingBoxes:
    """Unwrap a ``BoundingBoxes`` input, centre-crop it and re-wrap with the new canvas size."""
    plain_boxes = inpt.as_subclass(torch.Tensor)
    cropped, new_canvas_size = center_crop_bounding_boxes(
        plain_boxes,
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        output_size=output_size,
    )
    return tv_tensors.wrap(cropped, like=inpt, canvas_size=new_canvas_size)
2038
2039


2040
@_register_kernel_internal(center_crop, tv_tensors.Mask)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Centre-crop a mask via ``center_crop_image``.

    Masks with fewer than 3 dimensions temporarily get a leading dimension, which is
    removed again before returning.
    """
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    cropped = center_crop_image(image=mask, output_size=output_size)

    if needs_squeeze:
        cropped = cropped.squeeze(0)
    return cropped
2054
2055


2056
@_register_kernel_internal(center_crop, tv_tensors.Video)
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Centre-crop a video by forwarding it unchanged to ``center_crop_image``."""
    cropped = center_crop_image(video, output_size)
    return cropped
2059
2060


2061
def resized_crop(
    inpt: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomResizedCrop` for details."""
    # Under TorchScript there is no type-based dispatch: go straight to the tensor kernel.
    if torch.jit.is_scripting():
        return resized_crop_image(
            inpt, top=top, left=left, height=height, width=width, size=size, interpolation=interpolation,
            antialias=antialias,
        )

    _log_api_usage_once(resized_crop)

    # Eager mode: pick the kernel registered for the concrete input type.
    kernel = _get_kernel(resized_crop, type(inpt))
    return kernel(
        inpt, top=top, left=left, height=height, width=width, size=size, interpolation=interpolation,
        antialias=antialias,
    )
2097

2098
2099

@_register_kernel_internal(resized_crop, torch.Tensor)
@_register_kernel_internal(resized_crop, tv_tensors.Image)
def resized_crop_image(
    image: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Crop a tensor image with ``crop_image`` and resize the crop with ``resize_image``."""
    cropped = crop_image(image, top, left, height, width)
    return resize_image(cropped, size, interpolation=interpolation, antialias=antialias)
2113
2114


2115
def _resized_crop_image_pil(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
    """Crop a PIL image with ``_crop_image_pil`` and resize the crop with ``_resize_image_pil``."""
    cropped = _crop_image_pil(image, top, left, height, width)
    return _resize_image_pil(cropped, size, interpolation=interpolation)
2126
2127


2128
@_register_kernel_internal(resized_crop, PIL.Image.Image)
def _resized_crop_image_pil_dispatch(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> PIL.Image.Image:
    """Dispatch entry for PIL inputs: warn about an ignored ``antialias=False``, then crop and resize.

    ``antialias`` is not forwarded to ``_resized_crop_image_pil``.
    """
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")

    return _resized_crop_image_pil(
        image, top=top, left=left, height=height, width=width, size=size, interpolation=interpolation
    )


2152
2153
def resized_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Crop bounding boxes, then resize them from the cropped canvas to ``size``.

    The canvas size returned by ``crop_bounding_boxes`` is what ``resize_bounding_boxes``
    uses as its input canvas.
    """
    cropped_boxes, cropped_canvas = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
    return resize_bounding_boxes(cropped_boxes, canvas_size=cropped_canvas, size=size)


2165
@_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resized_crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.BoundingBoxes:
    """Unwrap a ``BoundingBoxes`` input, crop-and-resize it and re-wrap with the new canvas size.

    Extra keyword arguments are accepted and ignored.
    """
    plain_boxes = inpt.as_subclass(torch.Tensor)
    resized, new_canvas_size = resized_crop_bounding_boxes(
        plain_boxes,
        format=inpt.format,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
    )
    return tv_tensors.wrap(resized, like=inpt, canvas_size=new_canvas_size)
2173
2174


2175
def resized_crop_mask(
    mask: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> torch.Tensor:
    """Crop a mask with ``crop_mask`` and resize the crop with ``resize_mask``."""
    cropped = crop_mask(mask, top, left, height, width)
    return resize_mask(cropped, size)
2185
2186


2187
@_register_kernel_internal(resized_crop, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resized_crop_mask_dispatch(
    inpt: tv_tensors.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.Mask:
    """Unwrap a ``Mask`` input, run ``resized_crop_mask`` and re-wrap the result.

    Extra keyword arguments are accepted and ignored.
    """
    plain_mask = inpt.as_subclass(torch.Tensor)
    resized = resized_crop_mask(plain_mask, top=top, left=left, height=height, width=width, size=size)
    return tv_tensors.wrap(resized, like=inpt)
2195
2196


2197
@_register_kernel_internal(resized_crop, tv_tensors.Video)
def resized_crop_video(
    video: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
    """Crop and resize a video by forwarding it unchanged to ``resized_crop_image``."""
    return resized_crop_image(
        video,
        top,
        left,
        height,
        width,
        size=size,
        interpolation=interpolation,
        antialias=antialias,
    )


2213
def five_crop(
    inpt: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """[BETA] See :class:`~torchvision.transforms.v2.FiveCrop` for details."""
    # Under TorchScript there is no type-based dispatch: go straight to the tensor kernel.
    if torch.jit.is_scripting():
        return five_crop_image(inpt, size=size)

    _log_api_usage_once(five_crop)

    # Eager mode: pick the kernel registered for the concrete input type.
    return _get_kernel(five_crop, type(inpt))(inpt, size=size)
2224
2225


2226
2227
def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
2228
2229
        s = int(size)
        size = [s, s]
2230
    elif isinstance(size, (tuple, list)) and len(size) == 1:
2231
2232
        s = size[0]
        size = [s, s]
2233
2234
2235
2236
2237
2238
2239

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size


2240
@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Image)
def five_crop_image(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the four corner crops and the centre crop of a tensor image.

    The result order is ``(top-left, top-right, bottom-left, bottom-right, center)``.

    Raises:
        ValueError: if the requested crop is larger than the image along any axis.
    """
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = image.shape[-2:]

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    # Offsets of the crops that are flush with the bottom / right border.
    last_top = image_height - crop_height
    last_left = image_width - crop_width

    top_left = crop_image(image, 0, 0, crop_height, crop_width)
    top_right = crop_image(image, 0, last_left, crop_height, crop_width)
    bottom_left = crop_image(image, last_top, 0, crop_height, crop_width)
    bottom_right = crop_image(image, last_top, last_left, crop_height, crop_width)
    center = center_crop_image(image, [crop_height, crop_width])

    return top_left, top_right, bottom_left, bottom_right, center


2260
@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image)
def _five_crop_image_pil(
    image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    """Crop the four corners and the center out of a PIL image.

    ``size`` is normalized with ``_parse_five_crop_size``; the image height
    and width come from ``_get_size_image_pil``.

    Returns:
        The crops in the order top-left, top-right, bottom-left,
        bottom-right, center.

    Raises:
        ValueError: If the requested crop is larger than the image in
            either dimension.
    """
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = _get_size_image_pil(image)

    if crop_height > image_height or crop_width > image_width:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    # Offsets of the crops that touch the bottom and right borders.
    bottom = image_height - crop_height
    right = image_width - crop_width

    top_left = _crop_image_pil(image, 0, 0, crop_height, crop_width)
    top_right = _crop_image_pil(image, 0, right, crop_height, crop_width)
    bottom_left = _crop_image_pil(image, bottom, 0, crop_height, crop_width)
    bottom_right = _crop_image_pil(image, bottom, right, crop_height, crop_width)
    center = _center_crop_image_pil(image, [crop_height, crop_width])

    return top_left, top_right, bottom_left, bottom_right, center


2279
@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Video)
def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Five-crop kernel registered for ``tv_tensors.Video`` inputs.

    Delegates directly to ``five_crop_image`` and returns its result.
    """
    return five_crop_image(video, size=size)
2284
2285


2286
def ten_crop(
    inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """[BETA] See :class:`~torchvision.transforms.v2.TenCrop` for details."""
    # When scripting, call the tensor kernel directly instead of going
    # through the type-based kernel lookup below.
    if torch.jit.is_scripting():
        return ten_crop_image(inpt, size=size, vertical_flip=vertical_flip)

    _log_api_usage_once(ten_crop)

    # Look up the kernel registered for this input type and delegate to it.
    return _get_kernel(ten_crop, type(inpt))(inpt, size=size, vertical_flip=vertical_flip)
2308
2309


2310
@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Image)
def ten_crop_image(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Five-crop a tensor image and a flipped copy of it.

    The first five entries of the result are the ``five_crop_image`` crops of
    ``image``; the last five are the crops of the flipped image. The flip is
    vertical when ``vertical_flip`` is true and horizontal otherwise.
    """
    original_crops = five_crop_image(image, size)

    if vertical_flip:
        mirrored = vertical_flip_image(image)
    else:
        mirrored = horizontal_flip_image(image)

    mirrored_crops = five_crop_image(mirrored, size)

    return original_crops + mirrored_crops
2336
2337


2338
@_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image)
def _ten_crop_image_pil(
    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
]:
    """Five-crop a PIL image and a flipped copy of it.

    The first five entries of the result are the ``_five_crop_image_pil``
    crops of ``image``; the last five are the crops of the flipped image. The
    flip is vertical when ``vertical_flip`` is true and horizontal otherwise.
    """
    original_crops = _five_crop_image_pil(image, size)

    # Pick the flip kernel once, then crop the flipped image.
    flip = _vertical_flip_image_pil if vertical_flip else _horizontal_flip_image_pil
    mirrored_crops = _five_crop_image_pil(flip(image), size)

    return original_crops + mirrored_crops


2365
@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Video)
def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Ten-crop kernel registered for ``tv_tensors.Video`` inputs.

    Delegates directly to ``ten_crop_image`` and returns its result.
    """
    return ten_crop_image(video, size=size, vertical_flip=vertical_flip)