_meta.py 10.3 KB
Newer Older
1
from typing import List, Optional, Tuple
2

3
4
import PIL.Image
import torch
5
from torchvision import tv_tensors
6
from torchvision.transforms import _functional_pil as _FP
7
from torchvision.tv_tensors import BoundingBoxFormat
8

9
10
from torchvision.utils import _log_api_usage_once

11
from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
12

13

14
def get_dimensions(inpt: torch.Tensor) -> List[int]:
    """Return the dimensions ``[channels, height, width]`` of the input.

    Dispatches to the kernel registered for ``type(inpt)``; under TorchScript
    only the plain-tensor image kernel is available.
    """
    if torch.jit.is_scripting():
        return get_dimensions_image(inpt)

    _log_api_usage_once(get_dimensions)

    return _get_kernel(get_dimensions, type(inpt))(inpt)
22
23


24
@_register_kernel_internal(get_dimensions, torch.Tensor)
@_register_kernel_internal(get_dimensions, tv_tensors.Image, tv_tensor_wrapper=False)
def get_dimensions_image(image: torch.Tensor) -> List[int]:
    """Kernel for :func:`get_dimensions` on (possibly batched) image tensors."""
    trailing = list(image.shape[-3:])
    ndims = len(trailing)
    if ndims == 3:
        # [..., C, H, W] -> [C, H, W]
        return trailing
    if ndims == 2:
        # [H, W] -> treated as a single-channel image
        return [1] + trailing
    raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")


38
# Register the PIL kernel for ``get_dimensions`` by reusing the v1 helper.
_get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
39
40


41
@_register_kernel_internal(get_dimensions, tv_tensors.Video, tv_tensor_wrapper=False)
def get_dimensions_video(video: torch.Tensor) -> List[int]:
    # The image kernel only reads the trailing three dimensions, so it works
    # unchanged for video tensors with extra leading dimensions.
    return get_dimensions_image(video)
Philip Meier's avatar
Philip Meier committed
44
45


46
def get_num_channels(inpt: torch.Tensor) -> int:
    """Return the number of channels of the input.

    Dispatches to the kernel registered for ``type(inpt)``; under TorchScript
    only the plain-tensor image kernel is available.
    """
    if torch.jit.is_scripting():
        return get_num_channels_image(inpt)

    _log_api_usage_once(get_num_channels)

    return _get_kernel(get_num_channels, type(inpt))(inpt)
54
55


56
@_register_kernel_internal(get_num_channels, torch.Tensor)
@_register_kernel_internal(get_num_channels, tv_tensors.Image, tv_tensor_wrapper=False)
def get_num_channels_image(image: torch.Tensor) -> int:
    """Kernel for :func:`get_num_channels` on (possibly batched) image tensors."""
    trailing = image.shape[-3:]
    ndims = len(trailing)
    if ndims == 3:
        # [..., C, H, W] -> C
        return trailing[0]
    if ndims == 2:
        # [H, W] -> treated as a single-channel image
        return 1
    raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")


69
# Register the PIL kernel for ``get_num_channels`` by reusing the v1 helper.
_get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
70
71


72
@_register_kernel_internal(get_num_channels, tv_tensors.Video, tv_tensor_wrapper=False)
def get_num_channels_video(video: torch.Tensor) -> int:
    # The image kernel only reads the trailing three dimensions, so it works
    # unchanged for video tensors with extra leading dimensions.
    return get_num_channels_image(video)
75
76


77
78
79
80
81
# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
# deprecating the old names.
# Backwards-compatible alias for the v1-style name.
get_image_num_channels = get_num_channels


82
def get_size(inpt: torch.Tensor) -> List[int]:
    """Return the size ``[height, width]`` of the input.

    Dispatches to the kernel registered for ``type(inpt)``; under TorchScript
    only the plain-tensor image kernel is available.
    """
    if torch.jit.is_scripting():
        return get_size_image(inpt)

    _log_api_usage_once(get_size)

    return _get_kernel(get_size, type(inpt))(inpt)
90
91


92
@_register_kernel_internal(get_size, torch.Tensor)
@_register_kernel_internal(get_size, tv_tensors.Image, tv_tensor_wrapper=False)
def get_size_image(image: torch.Tensor) -> List[int]:
    """Kernel for :func:`get_size` on (possibly batched) image tensors."""
    hw = list(image.shape[-2:])
    ndims = len(hw)
    if ndims != 2:
        raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
    # [..., H, W] -> [H, W]
    return hw
101
102


103
@_register_kernel_internal(get_size, PIL.Image.Image)
def _get_size_image_pil(image: PIL.Image.Image) -> List[int]:
    # ``_FP.get_image_size`` reports (width, height); this module's size
    # convention is [height, width], so swap the order.
    width, height = _FP.get_image_size(image)
    return [height, width]


109
@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
def get_size_video(video: torch.Tensor) -> List[int]:
    # The image kernel only reads the trailing two dimensions, so it works
    # unchanged for video tensors with extra leading dimensions.
    return get_size_image(video)
112
113


114
@_register_kernel_internal(get_size, tv_tensors.Mask, tv_tensor_wrapper=False)
def get_size_mask(mask: torch.Tensor) -> List[int]:
    # Masks end in [..., H, W] like images, so the image kernel applies
    # unchanged.
    return get_size_image(mask)
117
118


119
120
@_register_kernel_internal(get_size, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def get_size_bounding_boxes(bounding_box: tv_tensors.BoundingBoxes) -> List[int]:
    # Bounding boxes carry the size of the canvas they live on as metadata,
    # so no tensor inspection is needed.
    return list(bounding_box.canvas_size)
122
123


124
def get_num_frames(inpt: torch.Tensor) -> int:
    """Return the number of frames of the input video.

    Dispatches to the kernel registered for ``type(inpt)``; under TorchScript
    only the plain-tensor video kernel is available.
    """
    if torch.jit.is_scripting():
        return get_num_frames_video(inpt)

    _log_api_usage_once(get_num_frames)

    return _get_kernel(get_num_frames, type(inpt))(inpt)
132
133


134
@_register_kernel_internal(get_num_frames, torch.Tensor)
@_register_kernel_internal(get_num_frames, tv_tensors.Video, tv_tensor_wrapper=False)
def get_num_frames_video(video: torch.Tensor) -> int:
    # Frames are indexed along the fourth-from-last dimension of the tensor.
    return video.shape[-4]
138
139


140
141
def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
    xyxy = xywh if inplace else xywh.clone()
142
143
144
145
    xyxy[..., 2:] += xyxy[..., :2]
    return xyxy


146
147
def _xyxy_to_xywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
    xywh = xyxy if inplace else xyxy.clone()
148
149
150
151
    xywh[..., 2:] -= xywh[..., :2]
    return xywh


152
153
154
def _cxcywh_to_xyxy(cxcywh: torch.Tensor, inplace: bool) -> torch.Tensor:
    if not inplace:
        cxcywh = cxcywh.clone()
155

156
157
158
159
160
161
162
    # Trick to do fast division by 2 and ceil, without casting. It produces the same result as
    # `torchvision.ops._box_convert._box_cxcywh_to_xyxy`.
    half_wh = cxcywh[..., 2:].div(-2, rounding_mode=None if cxcywh.is_floating_point() else "floor").abs_()
    # (cx - width / 2) = x1, same for y1
    cxcywh[..., :2].sub_(half_wh)
    # (x1 + width) = x2, same for y2
    cxcywh[..., 2:].add_(cxcywh[..., :2])
163

164
165
166
167
168
169
170
171
172
173
174
175
176
    return cxcywh


def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
    if not inplace:
        xyxy = xyxy.clone()

    # (x2 - x1) = width, same for height
    xyxy[..., 2:].sub_(xyxy[..., :2])
    # (x1 * 2 + width) / 2 = x1 + width / 2 = x1 + (x2-x1)/2 = (x1 + x2)/2 = cx, same for cy
    xyxy[..., :2].mul_(2).add_(xyxy[..., 2:]).div_(2, rounding_mode=None if xyxy.is_floating_point() else "floor")

    return xyxy
177
178


Nicolas Hug's avatar
Nicolas Hug committed
179
def _convert_bounding_box_format(
    bounding_boxes: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
) -> torch.Tensor:
    # Convert between box formats by routing through the intermediate XYXY
    # representation: old_format -> XYXY -> new_format.

    if new_format == old_format:
        return bounding_boxes

    # TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
    if old_format == BoundingBoxFormat.XYWH:
        bounding_boxes = _xywh_to_xyxy(bounding_boxes, inplace)
    elif old_format == BoundingBoxFormat.CXCYWH:
        bounding_boxes = _cxcywh_to_xyxy(bounding_boxes, inplace)

    if new_format == BoundingBoxFormat.XYWH:
        bounding_boxes = _xyxy_to_xywh(bounding_boxes, inplace)
    elif new_format == BoundingBoxFormat.CXCYWH:
        bounding_boxes = _xyxy_to_cxcywh(bounding_boxes, inplace)

    return bounding_boxes
198
199


Nicolas Hug's avatar
Nicolas Hug committed
200
def convert_bounding_box_format(
    inpt: torch.Tensor,
    old_format: Optional[BoundingBoxFormat] = None,
    new_format: Optional[BoundingBoxFormat] = None,
    inplace: bool = False,
) -> torch.Tensor:
    """See :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat` for details."""
    # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for pure tensor
    # inputs as well as extract it from `tv_tensors.BoundingBoxes` inputs. However, putting a default value on
    # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
    # default error that would be thrown if `new_format` had no default value.
    if new_format is None:
        raise TypeError("convert_bounding_box_format() missing 1 required argument: 'new_format'")

    if not torch.jit.is_scripting():
        _log_api_usage_once(convert_bounding_box_format)

    # Accept string format names (e.g. "xyxy") for convenience.
    if isinstance(old_format, str):
        old_format = BoundingBoxFormat[old_format.upper()]
    if isinstance(new_format, str):
        new_format = BoundingBoxFormat[new_format.upper()]

    if torch.jit.is_scripting() or is_pure_tensor(inpt):
        # Kernel path: plain tensors carry no format metadata, so `old_format`
        # is required.
        if old_format is None:
            raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
        return _convert_bounding_box_format(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
    elif isinstance(inpt, tv_tensors.BoundingBoxes):
        # Functional path: the format comes from the tv_tensor itself;
        # passing `old_format` as well would be ambiguous.
        if old_format is not None:
            raise ValueError("For bounding box tv_tensor inputs, `old_format` must not be passed.")
        output = _convert_bounding_box_format(
            inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
        )
        # Re-wrap as a BoundingBoxes tv_tensor carrying the new format.
        return tv_tensors.wrap(output, like=inpt, format=new_format)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
        )


239
def _clamp_bounding_boxes(
    bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    # Clamp boxes to the canvas by converting to XYXY, clamping each
    # coordinate, and converting back to the original format and dtype.
    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
    #  BoundingBoxFormat instead of converting back and forth
    in_dtype = bounding_boxes.dtype
    # Integer inputs are promoted to float for the round-trip — presumably to
    # avoid integer rounding in the format conversions; cast back at the end.
    bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
    xyxy_boxes = convert_bounding_box_format(
        bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
    )
    # `canvas_size` is (height, width): clamp x to [0, width], y to [0, height].
    xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])
    xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])
    out_boxes = convert_bounding_box_format(
        xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
    )
    return out_boxes.to(in_dtype)
255
256


257
def clamp_bounding_boxes(
    inpt: torch.Tensor,
    format: Optional[BoundingBoxFormat] = None,
    canvas_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
    """See :func:`~torchvision.transforms.v2.ClampBoundingBoxes` for details."""
    if not torch.jit.is_scripting():
        _log_api_usage_once(clamp_bounding_boxes)

    if torch.jit.is_scripting() or is_pure_tensor(inpt):
        # Kernel path: plain tensors carry no metadata, so both `format` and
        # `canvas_size` must be supplied explicitly.
        if format is None or canvas_size is None:
            raise ValueError("For pure tensor inputs, `format` and `canvas_size` have to be passed.")
        return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
    elif isinstance(inpt, tv_tensors.BoundingBoxes):
        # Functional path: metadata comes from the tv_tensor itself; passing
        # it explicitly as well would be ambiguous.
        if format is not None or canvas_size is not None:
            raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.")
        output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
        return tv_tensors.wrap(output, like=inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
        )