from typing import List, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision import datapoints
from torchvision.datapoints import BoundingBoxFormat
from torchvision.transforms import _functional_pil as _FP

from torchvision.utils import _log_api_usage_once

from ._utils import _get_kernel, _register_kernel_internal, _register_unsupported_type, is_simple_tensor


@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_dimensions)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_dimensions_image_tensor(inpt)
    elif isinstance(inpt, datapoints.Datapoint):
        kernel = _get_kernel(get_dimensions, type(inpt))
        return kernel(inpt)
    elif isinstance(inpt, PIL.Image.Image):
        return get_dimensions_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
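
# Usage sketch (illustrative only, not part of the module's API surface): the dispatcher above picks a
# kernel based on the input type and always returns the dimensions as [channels, height, width], e.g.
#
#   get_dimensions(torch.rand(3, 32, 64))            # [3, 32, 64] via the tensor kernel
#   get_dimensions(PIL.Image.new("RGB", (64, 32)))   # [3, 32, 64] via the PIL kernel (PIL sizes are (W, H))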


@_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False)
def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
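    # Accept (..., C, H, W) inputs as well as bare (H, W) images, which are treated as single-channel.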
    chw = list(image.shape[-3:])
    ndims = len(chw)
    if ndims == 3:
        return chw
    elif ndims == 2:
        chw.insert(0, 1)
        return chw
    else:
        raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")


get_dimensions_image_pil = _FP.get_dimensions


@_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
def get_dimensions_video(video: torch.Tensor) -> List[int]:
    return get_dimensions_image_tensor(video)


@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_num_channels)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_num_channels_image_tensor(inpt)
    elif isinstance(inpt, datapoints.Datapoint):
        kernel = _get_kernel(get_num_channels, type(inpt))
        return kernel(inpt)
    elif isinstance(inpt, PIL.Image.Image):
        return get_num_channels_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


@_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False)
def get_num_channels_image_tensor(image: torch.Tensor) -> int:
    chw = image.shape[-3:]
    ndims = len(chw)
    if ndims == 3:
        return chw[0]
    elif ndims == 2:
        return 1
    else:
        raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")


get_num_channels_image_pil = _FP.get_image_num_channels


@_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
def get_num_channels_video(video: torch.Tensor) -> int:
    return get_num_channels_image_tensor(video)


# We changed the names so that they apply not only to images but also to videos. Thus, we simply keep the old
# names as aliases without deprecating them.
get_image_num_channels = get_num_channels


def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_size)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_size_image_tensor(inpt)
    elif isinstance(inpt, datapoints.Datapoint):
        kernel = _get_kernel(get_size, type(inpt))
        return kernel(inpt)
    elif isinstance(inpt, PIL.Image.Image):
        return get_size_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


@_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False)
def get_size_image_tensor(image: torch.Tensor) -> List[int]:
    hw = list(image.shape[-2:])
    ndims = len(hw)
    if ndims == 2:
        return hw
    else:
        raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")


@torch.jit.unused
def get_size_image_pil(image: PIL.Image.Image) -> List[int]:
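    # PIL reports the size as (width, height); flip it to torchvision's [height, width] convention.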
    width, height = _FP.get_image_size(image)
    return [height, width]


@_register_kernel_internal(get_size, datapoints.Video, datapoint_wrapper=False)
def get_size_video(video: torch.Tensor) -> List[int]:
    return get_size_image_tensor(video)


@_register_kernel_internal(get_size, datapoints.Mask, datapoint_wrapper=False)
def get_size_mask(mask: torch.Tensor) -> List[int]:
    return get_size_image_tensor(mask)


@_register_kernel_internal(get_size, datapoints.BoundingBoxes, datapoint_wrapper=False)
def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
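    # Bounding boxes carry no pixel data of their own; their size is the (height, width) of the canvas they live on.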
    return list(bounding_box.canvas_size)


@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_num_frames)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_num_frames_video(inpt)
    elif isinstance(inpt, datapoints.Datapoint):
        kernel = _get_kernel(get_num_frames, type(inpt))
        return kernel(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
        )


@_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False)
def get_num_frames_video(video: torch.Tensor) -> int:
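    # Videos are expected in (..., T, C, H, W) layout, so the number of frames is the fourth dimension from the end.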
    return video.shape[-4]


def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
    xyxy = xywh if inplace else xywh.clone()
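    # (x, y) is the top-left corner; adding (w, h) to it in place yields the bottom-right corner (x2, y2).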
    xyxy[..., 2:] += xyxy[..., :2]
    return xyxy


def _xyxy_to_xywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
    xywh = xyxy if inplace else xyxy.clone()
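    # Subtracting the top-left corner from (x2, y2) in place yields (w, h).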
    xywh[..., 2:] -= xywh[..., :2]
    return xywh


def _cxcywh_to_xyxy(cxcywh: torch.Tensor, inplace: bool) -> torch.Tensor:
    if not inplace:
        cxcywh = cxcywh.clone()

    # Trick to do fast division by 2 and ceil, without casting. It produces the same result as
    # `torchvision.ops._box_convert._box_cxcywh_to_xyxy`.
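    # E.g. for an integer width of 5: div(5, -2, rounding_mode="floor") = -3 and abs() gives 3, i.e. ceil(5 / 2).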
    half_wh = cxcywh[..., 2:].div(-2, rounding_mode=None if cxcywh.is_floating_point() else "floor").abs_()
    # (cx - width / 2) = x1, same for y1
    cxcywh[..., :2].sub_(half_wh)
    # (x1 + width) = x2, same for y2
    cxcywh[..., 2:].add_(cxcywh[..., :2])

    return cxcywh


def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
    if not inplace:
        xyxy = xyxy.clone()

    # (x2 - x1) = width, same for height
    xyxy[..., 2:].sub_(xyxy[..., :2])
    # (x1 * 2 + width) / 2 = x1 + width / 2 = x1 + (x2-x1)/2 = (x1 + x2)/2 = cx, same for cy
    xyxy[..., :2].mul_(2).add_(xyxy[..., 2:]).div_(2, rounding_mode=None if xyxy.is_floating_point() else "floor")

    return xyxy


def _convert_format_bounding_boxes(
    bounding_boxes: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
) -> torch.Tensor:
    if new_format == old_format:
        return bounding_boxes

    # TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
    if old_format == BoundingBoxFormat.XYWH:
        bounding_boxes = _xywh_to_xyxy(bounding_boxes, inplace)
    elif old_format == BoundingBoxFormat.CXCYWH:
        bounding_boxes = _cxcywh_to_xyxy(bounding_boxes, inplace)

    if new_format == BoundingBoxFormat.XYWH:
        bounding_boxes = _xyxy_to_xywh(bounding_boxes, inplace)
    elif new_format == BoundingBoxFormat.CXCYWH:
        bounding_boxes = _xyxy_to_cxcywh(bounding_boxes, inplace)

    return bounding_boxes


def convert_format_bounding_boxes(
    inpt: datapoints._InputTypeJIT,
    old_format: Optional[BoundingBoxFormat] = None,
    new_format: Optional[BoundingBoxFormat] = None,
    inplace: bool = False,
) -> datapoints._InputTypeJIT:
    # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
    # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
    # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic
    # the default error that would be thrown if `new_format` had no default value.
    if new_format is None:
        raise TypeError("convert_format_bounding_boxes() missing 1 required argument: 'new_format'")

    if not torch.jit.is_scripting():
        _log_api_usage_once(convert_format_bounding_boxes)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        if old_format is None:
            raise ValueError("For simple tensor inputs, `old_format` has to be passed.")
        return _convert_format_bounding_boxes(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
    elif isinstance(inpt, datapoints.BoundingBoxes):
        if old_format is not None:
            raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
        output = _convert_format_bounding_boxes(
            inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
        )
        return datapoints.BoundingBoxes.wrap_like(inpt, output, format=new_format)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
        )
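
# Usage sketch (illustrative only): plain tensors need `old_format`; datapoints carry their own format.
#
#   boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # XYWH
#   convert_format_bounding_boxes(
#       boxes, old_format=BoundingBoxFormat.XYWH, new_format=BoundingBoxFormat.XYXY
#   )  # tensor([[10., 20., 40., 60.]])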


def _clamp_bounding_boxes(
    bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
    #  BoundingBoxFormat instead of converting back and forth
    in_dtype = bounding_boxes.dtype
    bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
    xyxy_boxes = convert_format_bounding_boxes(
        bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
    )
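    # canvas_size is (height, width): clamp x coordinates (even indices) to the width and y coordinates
    # (odd indices) to the height.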
    xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])
    xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])
    out_boxes = convert_format_bounding_boxes(
        xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
    )
    return out_boxes.to(in_dtype)


def clamp_bounding_boxes(
    inpt: datapoints._InputTypeJIT,
    format: Optional[BoundingBoxFormat] = None,
    canvas_size: Optional[Tuple[int, int]] = None,
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(clamp_bounding_boxes)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        if format is None or canvas_size is None:
            raise ValueError("For simple tensor inputs, `format` and `canvas_size` have to be passed.")
        return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
    elif isinstance(inpt, datapoints.BoundingBoxes):
        if format is not None or canvas_size is not None:
            raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.")
        output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
        return datapoints.BoundingBoxes.wrap_like(inpt, output)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
        )
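
# Usage sketch (illustrative only): clamping restricts box coordinates to the canvas.
#
#   boxes = datapoints.BoundingBoxes(
#       [[-5.0, 10.0, 200.0, 300.0]], format=BoundingBoxFormat.XYXY, canvas_size=(100, 150)
#   )
#   clamp_bounding_boxes(boxes)  # BoundingBoxes([[0., 10., 150., 100.]])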