from typing import List, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision import datapoints
from torchvision.datapoints import BoundingBoxFormat
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once

from ._utils import is_simple_tensor


def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
    chw = list(image.shape[-3:])
    ndims = len(chw)
    if ndims == 3:
        return chw
    elif ndims == 2:
        chw.insert(0, 1)
        return chw
    else:
        raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")


get_dimensions_image_pil = _FP.get_dimensions


def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_dimensions)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_dimensions_image_tensor(inpt)
    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
        channels = inpt.num_channels
        height, width = inpt.spatial_size
        return [channels, height, width]
    elif isinstance(inpt, PIL.Image.Image):
        return get_dimensions_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


def get_num_channels_image_tensor(image: torch.Tensor) -> int:
    chw = image.shape[-3:]
    ndims = len(chw)
    if ndims == 3:
        return chw[0]
    elif ndims == 2:
        return 1
    else:
        raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")


get_num_channels_image_pil = _FP.get_image_num_channels


def get_num_channels_video(video: torch.Tensor) -> int:
    return get_num_channels_image_tensor(video)


def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_num_channels)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_num_channels_image_tensor(inpt)
    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
        return inpt.num_channels
    elif isinstance(inpt, PIL.Image.Image):
        return get_num_channels_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
# deprecating the old names.
get_image_num_channels = get_num_channels


def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
    hw = list(image.shape[-2:])
    ndims = len(hw)
    if ndims == 2:
        return hw
    else:
        raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")


@torch.jit.unused
def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
    width, height = _FP.get_image_size(image)
    return [height, width]


def get_spatial_size_video(video: torch.Tensor) -> List[int]:
    return get_spatial_size_image_tensor(video)


def get_spatial_size_mask(mask: torch.Tensor) -> List[int]:
    return get_spatial_size_image_tensor(mask)


@torch.jit.unused
def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[int]:
    return list(bounding_box.spatial_size)


def get_spatial_size(inpt: datapoints._InputTypeJIT) -> List[int]:
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_spatial_size)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_spatial_size_image_tensor(inpt)
    elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)):
        return list(inpt.spatial_size)
    elif isinstance(inpt, PIL.Image.Image):
        return get_spatial_size_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )


def get_num_frames_video(video: torch.Tensor) -> int:
    return video.shape[-4]


def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_num_frames)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_num_frames_video(inpt)
    elif isinstance(inpt, datapoints.Video):
        return inpt.num_frames
    else:
        raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")


def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
    xyxy = xywh if inplace else xywh.clone()
    xyxy[..., 2:] += xyxy[..., :2]
    return xyxy


def _xyxy_to_xywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
    xywh = xyxy if inplace else xyxy.clone()
    xywh[..., 2:] -= xywh[..., :2]
    return xywh


def _cxcywh_to_xyxy(cxcywh: torch.Tensor, inplace: bool) -> torch.Tensor:
    if not inplace:
        cxcywh = cxcywh.clone()

    # Trick to do fast division by 2 and ceil, without casting. It produces the same result as
    # `torchvision.ops._box_convert._box_cxcywh_to_xyxy`.
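    # For example (illustrative): an integer width of 5 gives `div(-2, rounding_mode="floor") -> -3` and
    # `abs_() -> 3`, i.e. `ceil(5 / 2)`, while floating point widths keep the exact `width / 2`.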
    half_wh = cxcywh[..., 2:].div(-2, rounding_mode=None if cxcywh.is_floating_point() else "floor").abs_()
    # (cx - width / 2) = x1, same for y1
    cxcywh[..., :2].sub_(half_wh)
    # (x1 + width) = x2, same for y2
    cxcywh[..., 2:].add_(cxcywh[..., :2])

    return cxcywh


def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
    if not inplace:
        xyxy = xyxy.clone()

    # (x2 - x1) = width, same for height
    xyxy[..., 2:].sub_(xyxy[..., :2])
    # (x1 * 2 + width) / 2 = x1 + width / 2 = x1 + (x2-x1)/2 = (x1 + x2)/2 = cx, same for cy
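    # e.g. (illustrative) x1 = 2, x2 = 6: width = 4 and cx = (2 * 2 + 4) / 2 = 4 = (2 + 6) / 2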
    xyxy[..., :2].mul_(2).add_(xyxy[..., 2:]).div_(2, rounding_mode=None if xyxy.is_floating_point() else "floor")

    return xyxy


def _convert_format_bounding_box(
    bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
) -> torch.Tensor:

    if new_format == old_format:
        return bounding_box

    # TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
    if old_format == BoundingBoxFormat.XYWH:
        bounding_box = _xywh_to_xyxy(bounding_box, inplace)
    elif old_format == BoundingBoxFormat.CXCYWH:
        bounding_box = _cxcywh_to_xyxy(bounding_box, inplace)

    if new_format == BoundingBoxFormat.XYWH:
        bounding_box = _xyxy_to_xywh(bounding_box, inplace)
    elif new_format == BoundingBoxFormat.CXCYWH:
        bounding_box = _xyxy_to_cxcywh(bounding_box, inplace)

    return bounding_box


def convert_format_bounding_box(
    inpt: datapoints._InputTypeJIT,
    old_format: Optional[BoundingBoxFormat] = None,
    new_format: Optional[BoundingBoxFormat] = None,
    inplace: bool = False,
) -> datapoints._InputTypeJIT:
    # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
    # inputs as well as extract it from `datapoints.BoundingBox` inputs. However, putting a default value on
    # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
    # default error that would be thrown if `new_format` had no default value.
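    # Usage sketch (hypothetical variable names): plain tensors need the source format, e.g.
    # `convert_format_bounding_box(boxes, old_format=BoundingBoxFormat.XYWH, new_format=BoundingBoxFormat.XYXY)`,
    # while `datapoints.BoundingBox` inputs carry their own format, so only `new_format` is passed.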
    if new_format is None:
        raise TypeError("convert_format_bounding_box() missing 1 required argument: 'new_format'")

    if not torch.jit.is_scripting():
        _log_api_usage_once(convert_format_bounding_box)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        if old_format is None:
            raise ValueError("For simple tensor inputs, `old_format` has to be passed.")
        return _convert_format_bounding_box(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
    elif isinstance(inpt, datapoints.BoundingBox):
        if old_format is not None:
            raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
        output = _convert_format_bounding_box(
            inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
        )
        return datapoints.BoundingBox.wrap_like(inpt, output, format=new_format)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
        )


def _clamp_bounding_box(
    bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
    #  BoundingBoxFormat instead of converting back and forth
    in_dtype = bounding_box.dtype
    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
    xyxy_boxes = convert_format_bounding_box(
        bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
    )
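    # In XYXY layout, the even indices (0::2) are the x coordinates and the odd indices (1::2) the y coordinates,
    # so x is clamped to the width (spatial_size[1]) and y to the height (spatial_size[0]).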
    xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
    xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
    out_boxes = convert_format_bounding_box(
        xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
    )
    return out_boxes.to(in_dtype)


def clamp_bounding_box(
    inpt: datapoints._InputTypeJIT,
    format: Optional[BoundingBoxFormat] = None,
    spatial_size: Optional[Tuple[int, int]] = None,
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
        _log_api_usage_once(clamp_bounding_box)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        if format is None or spatial_size is None:
            raise ValueError("For simple tensor inputs, `format` and `spatial_size` has to be passed.")
        return _clamp_bounding_box(inpt, format=format, spatial_size=spatial_size)
    elif isinstance(inpt, datapoints.BoundingBox):
        if format is not None or spatial_size is not None:
            raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
        output = _clamp_bounding_box(inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size)
        return datapoints.BoundingBox.wrap_like(inpt, output)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
        )


def _num_value_bits(dtype: torch.dtype) -> int:
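    # The number of bits that carry the value: all bits for unsigned integer dtypes, all but the sign bit for
    # signed ones (e.g. torch.uint8 -> 8, torch.int8 -> 7).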
    if dtype == torch.uint8:
        return 8
    elif dtype == torch.int8:
        return 7
    elif dtype == torch.int16:
        return 15
    elif dtype == torch.int32:
        return 31
    elif dtype == torch.int64:
        return 63
    else:
        raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")


def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:

    if image.dtype == dtype:
        return image
    elif not scale:
        return image.to(dtype)

    float_input = image.is_floating_point()
    if torch.jit.is_scripting():
        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
    else:
        float_output = dtype.is_floating_point

    if float_input:
        # float to float
        if float_output:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")

        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
        # to the integer dtype  is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # for a detailed analysis.
        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
        # Instead, we can also multiply by the maximum value plus something close to `1`. See
        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
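        # Illustrative example for uint8: `0.999 * 255 = 254.745` truncates to 254, whereas
        # `0.999 * (255 + 1 - 1e-3) = 255.743...` truncates to 255, so the top value is reachable for inputs
        # close to (not only exactly equal to) 1.0.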
        eps = 1e-3
        max_value = float(_max_value(dtype))
        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
        # discrete set `{0, 1}`.
        return image.mul(max_value + 1.0 - eps).to(dtype)
    else:
        # int to float
        if float_output:
            return image.to(dtype).mul_(1.0 / _max_value(image.dtype))

        # int to int
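        # The shift moves the value bits, e.g. (illustrative) uint8 -> int16 left-shifts by 15 - 8 = 7,
        # mapping 255 to 32640, while int16 -> uint8 right-shifts by 7, mapping 32767 back to 255.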
        num_value_bits_input = _num_value_bits(image.dtype)
        num_value_bits_output = _num_value_bits(dtype)

        if num_value_bits_input > num_value_bits_output:
            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
        else:
            return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)


# We encourage users to use to_dtype() instead but we keep this for BC
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    return to_dtype_image_tensor(image, dtype=dtype, scale=True)


def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    return to_dtype_image_tensor(video, dtype, scale=scale)


def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    if not torch.jit.is_scripting():
        _log_api_usage_once(to_dtype)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return to_dtype_image_tensor(inpt, dtype, scale=scale)
    elif isinstance(inpt, datapoints.Image):
        output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
        return datapoints.Image.wrap_like(inpt, output)
    elif isinstance(inpt, datapoints.Video):
        output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
        return datapoints.Video.wrap_like(inpt, output)
    elif isinstance(inpt, datapoints._datapoint.Datapoint):
        return inpt.to(dtype)
    else:
        raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.")