_bounding_box.py 3.91 KB
Newer Older
1
2
from __future__ import annotations

3
from enum import Enum
4
from typing import Any, Optional, Tuple, Union
Philip Meier's avatar
Philip Meier committed
5
6
7

import torch

8
from ._datapoint import Datapoint
Philip Meier's avatar
Philip Meier committed
9
10


11
class BoundingBoxFormat(Enum):
    """[BETA] Coordinate format of a bounding box.

    Available formats are

    * ``XYXY``
    * ``XYWH``
    * ``CXCYWH``

    Each member's value is identical to its name, so a format can be looked up either by
    name (``BoundingBoxFormat["XYXY"]``) or by value (``BoundingBoxFormat("XYXY")``).
    """

    XYXY = "XYXY"
    XYWH = "XYWH"
    CXCYWH = "CXCYWH"
Philip Meier's avatar
Philip Meier committed
24
25


26
class BoundingBoxes(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for bounding boxes.

    .. note::
        There should be only one :class:`~torchvision.datapoints.BoundingBoxes`
        instance per sample, e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``,
        although one :class:`~torchvision.datapoints.BoundingBoxes` object can
        contain multiple bounding boxes.

    Args:
        data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
        format (BoundingBoxFormat, str): Format of the bounding box. A string is upper-cased and looked up by
            member name in :class:`BoundingBoxFormat`.
        canvas_size (two-tuple of ints): Height and width of the corresponding image or video.
        dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the bounding box. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the bounding box is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the bounding box. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    # Coordinate format shared by every box held in this tensor.
    format: BoundingBoxFormat
    # (height, width) of the image or video the boxes refer to.
    canvas_size: Tuple[int, int]

    @classmethod
    def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes:  # type: ignore[override]
        """Reinterpret ``tensor`` as :class:`BoundingBoxes` and attach the metadata.

        A 1D tensor is treated as a single box and promoted to 2D with a leading
        dimension of size 1. The result is produced with :meth:`torch.Tensor.as_subclass`,
        which shares the underlying data with its input.

        Raises:
            ValueError: If ``tensor`` is neither 1D nor 2D.
            KeyError: If ``format`` is a string that does not name a :class:`BoundingBoxFormat` member.
        """
        if tensor.ndim == 1:
            # A single box given as a flat vector -> batch of one box.
            tensor = tensor.unsqueeze(0)
        elif tensor.ndim != 2:
            raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")

        if isinstance(format, str):
            # Case-insensitive lookup by member name, e.g. "xyxy" -> BoundingBoxFormat.XYXY.
            format = BoundingBoxFormat[format.upper()]

        bounding_boxes = tensor.as_subclass(cls)
        bounding_boxes.format = format
        bounding_boxes.canvas_size = canvas_size
        return bounding_boxes

    def __new__(
        cls,
        data: Any,
        *,
        format: Union[BoundingBoxFormat, str],
        canvas_size: Tuple[int, int],
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> BoundingBoxes:
        # Tensor conversion (dtype/device/requires_grad handling) is delegated to the
        # Datapoint base class; metadata attachment and shape checks happen in _wrap.
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return cls._wrap(tensor, format=format, canvas_size=canvas_size)

    @classmethod
    def wrap_like(
        cls,
        other: BoundingBoxes,
        tensor: torch.Tensor,
        *,
        format: Optional[Union[BoundingBoxFormat, str]] = None,
        canvas_size: Optional[Tuple[int, int]] = None,
    ) -> BoundingBoxes:
        """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.

        Args:
            other (BoundingBoxes): Reference bounding box.
            tensor (Tensor): Tensor to be wrapped as :class:`BoundingBoxes`.
            format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the
                reference.
            canvas_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If
                omitted, it is taken from the reference.
        """
        return cls._wrap(
            tensor,
            format=format if format is not None else other.format,
            canvas_size=canvas_size if canvas_size is not None else other.canvas_size,
        )

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        """Return the base-class representation with ``format`` and ``canvas_size`` added as extra fields.

        ``tensor_contents`` is accepted for signature compatibility with :meth:`torch.Tensor.__repr__`
        but is not used here.
        """
        return self._make_repr(format=self.format, canvas_size=self.canvas_size)