from typing import List, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision import datapoints
from torchvision.datapoints import BoundingBoxFormat
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once

from ._utils import is_simple_tensor


def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
    chw = list(image.shape[-3:])
    ndims = len(chw)
    if ndims == 3:
        return chw
    elif ndims == 2:
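        # a 2D (H, W) input is treated as a single-channel image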
        chw.insert(0, 1)
        return chw
    else:
        raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")


get_dimensions_image_pil = _FP.get_dimensions


def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_dimensions)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_dimensions_image_tensor(inpt)
    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
        channels = inpt.num_channels
        height, width = inpt.spatial_size
        return [channels, height, width]
    elif isinstance(inpt, PIL.Image.Image):
        return get_dimensions_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


def get_num_channels_image_tensor(image: torch.Tensor) -> int:
    chw = image.shape[-3:]
    ndims = len(chw)
    if ndims == 3:
        return chw[0]
    elif ndims == 2:
        return 1
    else:
        raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")


get_num_channels_image_pil = _FP.get_image_num_channels


def get_num_channels_video(video: torch.Tensor) -> int:
    return get_num_channels_image_tensor(video)


def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_num_channels)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_num_channels_image_tensor(inpt)
    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
        return inpt.num_channels
    elif isinstance(inpt, PIL.Image.Image):
        return get_num_channels_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


# We changed the names to ensure they can be used not only for images but also videos. Thus, we just alias them
# without deprecating the old names.
get_image_num_channels = get_num_channels


def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
    hw = list(image.shape[-2:])
    ndims = len(hw)
    if ndims == 2:
        return hw
    else:
        raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")


@torch.jit.unused
def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
    width, height = _FP.get_image_size(image)
    return [height, width]


def get_spatial_size_video(video: torch.Tensor) -> List[int]:
    return get_spatial_size_image_tensor(video)


def get_spatial_size_mask(mask: torch.Tensor) -> List[int]:
    return get_spatial_size_image_tensor(mask)


@torch.jit.unused
def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[int]:
    return list(bounding_box.spatial_size)


def get_spatial_size(inpt: datapoints._InputTypeJIT) -> List[int]:
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_spatial_size)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_spatial_size_image_tensor(inpt)
    elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)):
        return list(inpt.spatial_size)
    elif isinstance(inpt, PIL.Image.Image):
        return get_spatial_size_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


def get_num_frames_video(video: torch.Tensor) -> int:
    return video.shape[-4]


def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_num_frames)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_num_frames_video(inpt)
    elif isinstance(inpt, datapoints.Video):
        return inpt.num_frames
    else:
        raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")


def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
    xyxy = xywh if inplace else xywh.clone()
    xyxy[..., 2:] += xyxy[..., :2]
    return xyxy


def _xyxy_to_xywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
    xywh = xyxy if inplace else xyxy.clone()
    xywh[..., 2:] -= xywh[..., :2]
    return xywh


def _cxcywh_to_xyxy(cxcywh: torch.Tensor, inplace: bool) -> torch.Tensor:
    if not inplace:
        cxcywh = cxcywh.clone()

    # Trick to do fast division by 2 and ceil, without casting. It produces the same result as
    # `torchvision.ops._box_convert._box_cxcywh_to_xyxy`.
    half_wh = cxcywh[..., 2:].div(-2, rounding_mode=None if cxcywh.is_floating_point() else "floor").abs_()
    # (cx - width / 2) = x1, same for y1
    cxcywh[..., :2].sub_(half_wh)
    # (x1 + width) = x2, same for y2
    cxcywh[..., 2:].add_(cxcywh[..., :2])
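    # e.g. for an integer box [cx, cy, w, h] = [10, 10, 5, 5]: half_wh = ceil(5 / 2) = 3, so x1 = y1 = 10 - 3 = 7
    # and x2 = y2 = 7 + 5 = 12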

    return cxcywh


def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
    if not inplace:
        xyxy = xyxy.clone()

    # (x2 - x1) = width, same for height
    xyxy[..., 2:].sub_(xyxy[..., :2])
    # (x1 * 2 + width) / 2 = x1 + width / 2 = x1 + (x2-x1)/2 = (x1 + x2)/2 = cx, same for cy
    xyxy[..., :2].mul_(2).add_(xyxy[..., 2:]).div_(2, rounding_mode=None if xyxy.is_floating_point() else "floor")
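    # e.g. for an integer box [x1, y1, x2, y2] = [7, 7, 12, 12]: width = height = 12 - 7 = 5 and
    # cx = cy = (7 * 2 + 5) // 2 = 9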

    return xyxy


def _convert_format_bounding_box(
    bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
) -> torch.Tensor:

    if new_format == old_format:
        return bounding_box

    # TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
    if old_format == BoundingBoxFormat.XYWH:
        bounding_box = _xywh_to_xyxy(bounding_box, inplace)
    elif old_format == BoundingBoxFormat.CXCYWH:
        bounding_box = _cxcywh_to_xyxy(bounding_box, inplace)

    if new_format == BoundingBoxFormat.XYWH:
        bounding_box = _xyxy_to_xywh(bounding_box, inplace)
    elif new_format == BoundingBoxFormat.CXCYWH:
        bounding_box = _xyxy_to_cxcywh(bounding_box, inplace)

    return bounding_box


def convert_format_bounding_box(
    inpt: datapoints._InputTypeJIT,
    old_format: Optional[BoundingBoxFormat] = None,
    new_format: Optional[BoundingBoxFormat] = None,
    inplace: bool = False,
) -> datapoints._InputTypeJIT:
    # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
    # inputs as well as extract it from `datapoints.BoundingBox` inputs. However, putting a default value on
    # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
    # default error that would be thrown if `new_format` had no default value.
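    # For example, a plain tensor caller has to pass both formats,
    #   convert_format_bounding_box(boxes, old_format=BoundingBoxFormat.XYWH, new_format=BoundingBoxFormat.XYXY)
    # while a `datapoints.BoundingBox` caller only passes `new_format`, since `old_format` is read from the input.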
    if new_format is None:
        raise TypeError("convert_format_bounding_box() missing 1 required argument: 'new_format'")

    if not torch.jit.is_scripting():
        _log_api_usage_once(convert_format_bounding_box)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        if old_format is None:
            raise ValueError("For simple tensor inputs, `old_format` has to be passed.")
        return _convert_format_bounding_box(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
    elif isinstance(inpt, datapoints.BoundingBox):
        if old_format is not None:
            raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
        output = _convert_format_bounding_box(
            inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
        )
        return datapoints.BoundingBox.wrap_like(inpt, output, format=new_format)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
        )


def _clamp_bounding_box(
    bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
    #  BoundingBoxFormat instead of converting back and forth
    in_dtype = bounding_box.dtype
    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
    xyxy_boxes = convert_format_bounding_box(
        bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
    )
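    # `spatial_size` is (height, width), so x coordinates are clamped to the width and y coordinates to the height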
    xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
    xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
    out_boxes = convert_format_bounding_box(
        xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
    )
    return out_boxes.to(in_dtype)


def clamp_bounding_box(
    inpt: datapoints._InputTypeJIT,
    format: Optional[BoundingBoxFormat] = None,
    spatial_size: Optional[Tuple[int, int]] = None,
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(clamp_bounding_box)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        if format is None or spatial_size is None:
            raise ValueError("For simple tensor inputs, `format` and `spatial_size` have to be passed.")
        return _clamp_bounding_box(inpt, format=format, spatial_size=spatial_size)
    elif isinstance(inpt, datapoints.BoundingBox):
        if format is not None or spatial_size is not None:
            raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
        output = _clamp_bounding_box(inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size)
        return datapoints.BoundingBox.wrap_like(inpt, output)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
        )


def _num_value_bits(dtype: torch.dtype) -> int:
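    # "Value bits" are the bits that carry magnitude: signed dtypes reserve one bit for the sign, so e.g.
    # torch.int8 has 7 value bits while torch.uint8 has all 8.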
    if dtype == torch.uint8:
        return 8
    elif dtype == torch.int8:
        return 7
    elif dtype == torch.int16:
        return 15
    elif dtype == torch.int32:
        return 31
    elif dtype == torch.int64:
        return 63
    else:
        raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")


def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    if image.dtype == dtype:
        return image

    float_input = image.is_floating_point()
    if torch.jit.is_scripting():
        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
    else:
        float_output = dtype.is_floating_point

    if float_input:
        # float to float
        if float_output:
            return image.to(dtype)

        # float to int
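        # float32 cannot represent every value in the int32 / int64 range exactly (24-bit significand), and
        # float64 cannot represent every int64 value (53-bit significand), so these conversions are refused.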
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")

        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
        # to the integer dtype  is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # for a detailed analysis.
        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
        # Instead, we can also multiply by the maximum value plus something close to `1`. See
        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
        eps = 1e-3
        max_value = float(_max_value(dtype))
        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
        # discrete set `{0, 1}`.
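        # e.g. for float32 -> uint8 the scale factor is 255 + 1.0 - 1e-3 = 255.999, so 1.0 maps to 255 and
        # 0.5 maps to int(127.9995) = 127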
        return image.mul(max_value + 1.0 - eps).to(dtype)
    else:
        # int to float
        if float_output:
            return image.to(dtype).mul_(1.0 / _max_value(image.dtype))

        # int to int
        num_value_bits_input = _num_value_bits(image.dtype)
        num_value_bits_output = _num_value_bits(dtype)
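        # e.g. uint8 -> int32 left-shifts by 31 - 8 = 23 bits (255 becomes 2139095040, near the int32 maximum),
        # while int32 -> uint8 right-shifts by 23 bits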

        if num_value_bits_input > num_value_bits_output:
            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
        else:
            return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)


# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
convert_image_dtype = convert_dtype_image_tensor


def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    return convert_dtype_image_tensor(video, dtype)


def convert_dtype(
    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], dtype: torch.dtype = torch.float
) -> torch.Tensor:
    if not torch.jit.is_scripting():
        _log_api_usage_once(convert_dtype)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return convert_dtype_image_tensor(inpt, dtype)
    elif isinstance(inpt, datapoints.Image):
        output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
        return datapoints.Image.wrap_like(inpt, output)
    elif isinstance(inpt, datapoints.Video):
        output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
        return datapoints.Video.wrap_like(inpt, output)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or an `Image` or `Video` datapoint, but got {type(inpt)} instead."
        )