# _bounding_box.py
from __future__ import annotations

from typing import Any, List, Optional, Sequence, Tuple, Union

import torch
from torchvision._utils import StrEnum
from torchvision.transforms import InterpolationMode  # TODO: this needs to be moved out of transforms

from ._datapoint import _FillTypeJIT, Datapoint


class BoundingBoxFormat(StrEnum):
    XYXY = StrEnum.auto()
    XYWH = StrEnum.auto()
    CXCYWH = StrEnum.auto()
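# The members above follow the usual detection conventions: XYXY stores (x1, y1, x2, y2) corner
# coordinates, XYWH stores the top-left corner plus width/height, and CXCYWH stores the box
# center plus width/height. A purely illustrative sketch of resolving a format from a string:
#
#   >>> BoundingBoxFormat.from_str("XYXY")   # -> BoundingBoxFormat.XYXY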


class BoundingBox(Datapoint):
    format: BoundingBoxFormat
    spatial_size: Tuple[int, int]

    @classmethod
    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBox:
        bounding_box = tensor.as_subclass(cls)
        bounding_box.format = format
        bounding_box.spatial_size = spatial_size
        return bounding_box

    def __new__(
        cls,
        data: Any,
        *,
        format: Union[BoundingBoxFormat, str],
        spatial_size: Tuple[int, int],
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> BoundingBox:
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)

        # Accept the format as a string as well and normalize it to a BoundingBoxFormat member.
        if isinstance(format, str):
            format = BoundingBoxFormat.from_str(format.upper())

        return cls._wrap(tensor, format=format, spatial_size=spatial_size)
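    # Construction sketch (values are illustrative; anything accepted by Datapoint._to_tensor,
    # e.g. a nested list or a torch.Tensor, works as ``data``):
    #
    #   >>> BoundingBox([[0.0, 10.0, 20.0, 30.0]], format="xyxy", spatial_size=(50, 50))
    #
    # String formats are upper-cased before the lookup, so "xyxy" and "XYXY" are equivalent.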

    @classmethod
    def wrap_like(
        cls,
        other: BoundingBox,
        tensor: torch.Tensor,
        *,
        format: Optional[BoundingBoxFormat] = None,
        spatial_size: Optional[Tuple[int, int]] = None,
    ) -> BoundingBox:
        return cls._wrap(
            tensor,
            format=format if format is not None else other.format,
            spatial_size=spatial_size if spatial_size is not None else other.spatial_size,
        )
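    # Usage sketch: wrap_like re-attaches the metadata of an existing BoundingBox to a plain
    # tensor, optionally overriding it (``boxes`` is an illustrative BoundingBox instance):
    #
    #   >>> shifted = boxes.as_subclass(torch.Tensor) + 2   # plain tensor math drops the metadata
    #   >>> BoundingBox.wrap_like(boxes, shifted)           # re-wraps with boxes' format and spatial_size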

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr(format=self.format, spatial_size=self.spatial_size)

    def horizontal_flip(self) -> BoundingBox:
        output = self._F.horizontal_flip_bounding_box(
            self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
        )
        return BoundingBox.wrap_like(self, output)

    def vertical_flip(self) -> BoundingBox:
        output = self._F.vertical_flip_bounding_box(
            self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
        )
        return BoundingBox.wrap_like(self, output)
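    # Pattern note for the geometry methods below: the box is first viewed as a plain torch.Tensor,
    # the corresponding low-level kernel reached through self._F (provided by the Datapoint base
    # class) transforms the raw coordinates, and the result is re-wrapped with wrap_like, updating
    # spatial_size whenever the kernel reports a new one.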

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> BoundingBox:
        # interpolation and antialias only affect image-like datapoints; for bounding boxes the
        # kernel just rescales the coordinates to the new spatial size.
        output, spatial_size = self._F.resize_bounding_box(
            self.as_subclass(torch.Tensor),
            spatial_size=self.spatial_size,
            size=size,
            max_size=max_size,
        )
        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
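    # Sketch: resizing rescales the coordinates and updates spatial_size (illustrative values):
    #
    #   >>> boxes = BoundingBox([[10, 10, 20, 20]], format="XYXY", spatial_size=(100, 100))
    #   >>> boxes.resize([200, 200]).spatial_size   # -> (200, 200), coordinates scaled accordingly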

    def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
        output, spatial_size = self._F.crop_bounding_box(
            self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
        )
        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

    def center_crop(self, output_size: List[int]) -> BoundingBox:
        output, spatial_size = self._F.center_crop_bounding_box(
            self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size
        )
        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
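    # Note: both crop variants shift the coordinates into the crop's frame and adopt the crop's
    # (height, width) as the new spatial_size; boxes that end up partially or fully outside the
    # crop are not filtered out by these methods.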

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> BoundingBox:
        output, spatial_size = self._F.resized_crop_bounding_box(
            self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
        )
        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

    def pad(
        self,
        padding: Union[int, Sequence[int]],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> BoundingBox:
        # fill only applies to image-like datapoints; the box kernel needs only padding and padding_mode.
        output, spatial_size = self._F.pad_bounding_box(
            self.as_subclass(torch.Tensor),
            format=self.format,
            spatial_size=self.spatial_size,
            padding=padding,
            padding_mode=padding_mode,
        )
        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> BoundingBox:
        output, spatial_size = self._F.rotate_bounding_box(
            self.as_subclass(torch.Tensor),
            format=self.format,
            spatial_size=self.spatial_size,
            angle=angle,
            expand=expand,
            center=center,
        )
        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
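    # Note: with expand=True the rotated canvas is enlarged to fit the whole rotation, so the
    # spatial_size returned by the kernel (and stored on the result) can differ from the input's.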

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> BoundingBox:
        output = self._F.affine_bounding_box(
            self.as_subclass(torch.Tensor),
            self.format,
            self.spatial_size,
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            center=center,
        )
        return BoundingBox.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> BoundingBox:
        output = self._F.perspective_bounding_box(
            self.as_subclass(torch.Tensor),
            format=self.format,
            spatial_size=self.spatial_size,
            startpoints=startpoints,
            endpoints=endpoints,
            coefficients=coefficients,
        )
        return BoundingBox.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
    ) -> BoundingBox:
        output = self._F.elastic_bounding_box(
            self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement
        )
        return BoundingBox.wrap_like(self, output)
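# End-to-end sketch (illustrative values; assumes this class is re-exported as
# torchvision.datapoints.BoundingBox in the release that ships this file):
#
#   >>> import torch
#   >>> from torchvision import datapoints
#   >>> boxes = datapoints.BoundingBox(
#   ...     torch.tensor([[15.0, 10.0, 70.0, 60.0]]), format="XYXY", spatial_size=(100, 100)
#   ... )
#   >>> boxes.horizontal_flip()   # returns a BoundingBox with the same format and spatial_size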