from __future__ import annotations

from enum import Enum
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch.utils._pytree import tree_flatten

from ._datapoint import Datapoint
class BoundingBoxFormat(Enum):
    """[BETA] Coordinate format of a bounding box.

    Available formats are

    * ``XYXY``
    * ``XYWH``
    * ``CXCYWH``
    """

    # Member names are uppercase on purpose: ``BoundingBoxes._wrap`` converts a
    # ``str`` format via ``BoundingBoxFormat[format.upper()]``, i.e. it indexes
    # this enum by name.
    XYXY = "XYXY"
    XYWH = "XYWH"
    CXCYWH = "CXCYWH"
class BoundingBoxes(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for bounding boxes.

    .. note::
        A sample should carry at most one
        :class:`~torchvision.datapoints.BoundingBoxes` instance, e.g.
        ``{"img": img, "bbox": BoundingBoxes(...)}``, although that single
        :class:`~torchvision.datapoints.BoundingBoxes` object may hold many
        bounding boxes.

    Args:
        data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
        format (BoundingBoxFormat, str): Format of the bounding box.
        canvas_size (two-tuple of ints): Height and width of the corresponding image or video.
        dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the bounding box. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the bounding box is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the bounding box. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    # Metadata carried alongside the tensor payload; restored by ``_wrap_output``
    # after ``__torch_function__`` strips it.
    format: BoundingBoxFormat
    canvas_size: Tuple[int, int]

    @classmethod
    def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int], check_dims: bool = True) -> BoundingBoxes:  # type: ignore[override]
        # Attach bounding-box metadata to ``tensor`` without copying it.
        if check_dims:
            # Normalize a single box (1D) to a batch of one; anything other
            # than 1D/2D is rejected.
            ndim = tensor.ndim
            if ndim == 1:
                tensor = tensor.unsqueeze(0)
            elif ndim != 2:
                raise ValueError(f"Expected a 1D or 2D tensor, got {ndim}D")
        if isinstance(format, str):
            # Enum members are uppercase, so lookup is by upper-cased name.
            format = BoundingBoxFormat[format.upper()]
        boxes = tensor.as_subclass(cls)
        boxes.format = format
        boxes.canvas_size = canvas_size
        return boxes

    def __new__(
        cls,
        data: Any,
        *,
        format: Union[BoundingBoxFormat, str],
        canvas_size: Tuple[int, int],
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> BoundingBoxes:
        # Convert ``data`` to a plain tensor first, then attach the metadata.
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return cls._wrap(tensor, format=format, canvas_size=canvas_size)

    @classmethod
    def wrap_like(
        cls,
        other: BoundingBoxes,
        tensor: torch.Tensor,
        *,
        format: Optional[Union[BoundingBoxFormat, str]] = None,
        canvas_size: Optional[Tuple[int, int]] = None,
    ) -> BoundingBoxes:
        """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.

        Args:
            other (BoundingBoxes): Reference bounding box.
            tensor (Tensor): Tensor to be wrapped as :class:`BoundingBoxes`
            format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the
                reference.
            canvas_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If
                omitted, it is taken from the reference.

        """
        return cls._wrap(
            tensor,
            format=other.format if format is None else format,
            canvas_size=other.canvas_size if canvas_size is None else canvas_size,
        )

    @classmethod
    def _wrap_output(
        cls,
        output: torch.Tensor,
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> BoundingBoxes:
        # ``super().__torch_function__`` drops our metadata from any
        # BoundingBoxes produced by the op, so we recover it from the first
        # BoundingBoxes found among the call's parameters. That is the right
        # answer in most cases; mixing boxes with differing metadata in one op
        # (e.g. some_xyxy_bbox + some_xywh_bbox) is a mis-use we deliberately
        # do not guard against.
        flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ()))  # type: ignore[operator]
        reference = next(param for param in flat_params if isinstance(param, BoundingBoxes))
        fmt, size = reference.format, reference.canvas_size

        if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes):
            return BoundingBoxes._wrap(output, format=fmt, canvas_size=size, check_dims=False)
        if isinstance(output, (tuple, list)):
            rewrapped = (
                BoundingBoxes._wrap(part, format=fmt, canvas_size=size, check_dims=False) for part in output
            )
            return type(output)(rewrapped)
        return output

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr(format=self.format, canvas_size=self.canvas_size)