_mask.py 5.67 KB
Newer Older
1
2
from __future__ import annotations

3
from typing import Any, List, Optional, Tuple, Union
4

5
import PIL.Image
6
import torch
7
8
from torchvision.transforms import InterpolationMode

Philip Meier's avatar
Philip Meier committed
9
from ._datapoint import _FillTypeJIT, Datapoint
10
11


12
class Mask(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for segmentation and detection masks.

    Args:
        data (tensor-like, PIL.Image.Image): Any data that can be turned into a tensor with :func:`torch.as_tensor` as
            well as PIL images.
        dtype (torch.dtype, optional): Desired data type of the mask. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the mask. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the mask is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the mask. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Mask:
        """Return ``tensor`` re-typed as ``cls`` via :meth:`torch.Tensor.as_subclass`."""
        return tensor.as_subclass(cls)

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Mask:
        """Build a mask from tensor-like data or a PIL image.

        PIL images are first converted with ``pil_to_tensor``; everything is then passed through
        ``cls._to_tensor`` together with ``dtype``, ``device`` and ``requires_grad`` and wrapped as a ``Mask``.
        """
        if isinstance(data, PIL.Image.Image):
            # NOTE(review): function-scope import, presumably to avoid a circular import between the datapoints
            # and the v2 functional namespace -- confirm before hoisting it to module level.
            from torchvision.transforms.v2 import functional as F

            data = F.pil_to_tensor(data)

        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return cls._wrap(tensor)

    @classmethod
    def wrap_like(
        cls,
        other: Mask,
        tensor: torch.Tensor,
    ) -> Mask:
        """Wrap ``tensor`` as a ``Mask``.

        ``other`` is accepted but not used by this implementation; only ``tensor`` is wrapped.
        """
        return cls._wrap(tensor)

    @property
    def spatial_size(self) -> Tuple[int, int]:
        """The last two entries of ``self.shape`` as a tuple."""
        return tuple(self.shape[-2:])  # type: ignore[return-value]

    def horizontal_flip(self) -> Mask:
        """Apply ``self._F.horizontal_flip_mask`` to the plain tensor and re-wrap the result as a ``Mask``."""
        output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor))
        return Mask.wrap_like(self, output)

    def vertical_flip(self) -> Mask:
        """Apply ``self._F.vertical_flip_mask`` to the plain tensor and re-wrap the result as a ``Mask``."""
        output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor))
        return Mask.wrap_like(self, output)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Mask:
        """Apply ``self._F.resize_mask`` with ``size`` and ``max_size`` and re-wrap the result as a ``Mask``.

        ``interpolation`` and ``antialias`` are part of the signature but are not forwarded to ``resize_mask``.
        """
        output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size)
        return Mask.wrap_like(self, output)

    def crop(self, top: int, left: int, height: int, width: int) -> Mask:
        """Apply ``self._F.crop_mask`` with the given window and re-wrap the result as a ``Mask``."""
        output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width)
        return Mask.wrap_like(self, output)

    def center_crop(self, output_size: List[int]) -> Mask:
        """Apply ``self._F.center_crop_mask`` with ``output_size`` and re-wrap the result as a ``Mask``."""
        output = self._F.center_crop_mask(self.as_subclass(torch.Tensor), output_size=output_size)
        return Mask.wrap_like(self, output)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Mask:
        """Apply ``self._F.resized_crop_mask`` with the crop window and ``size`` and re-wrap the result.

        ``interpolation`` and ``antialias`` are part of the signature but are not forwarded to
        ``resized_crop_mask``.
        """
        output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
        return Mask.wrap_like(self, output)

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Mask:
        """Apply ``self._F.pad_mask`` with ``padding``, ``padding_mode`` and ``fill`` and re-wrap the result."""
        output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
        return Mask.wrap_like(self, output)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Mask:
        """Apply ``self._F.rotate_mask`` with ``angle``, ``expand``, ``center`` and ``fill`` and re-wrap the result.

        ``interpolation`` is part of the signature but is not forwarded to ``rotate_mask``.
        """
        output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
        return Mask.wrap_like(self, output)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Mask:
        """Apply ``self._F.affine_mask`` with the given affine parameters and re-wrap the result as a ``Mask``.

        ``interpolation`` is part of the signature but is not forwarded to ``affine_mask``.
        """
        output = self._F.affine_mask(
            self.as_subclass(torch.Tensor),
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            fill=fill,
            center=center,
        )
        return Mask.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Mask:
        """Apply ``self._F.perspective_mask`` with the given points/coefficients and re-wrap the result.

        ``interpolation`` is part of the signature but is not forwarded to ``perspective_mask``.
        """
        output = self._F.perspective_mask(
            self.as_subclass(torch.Tensor), startpoints, endpoints, fill=fill, coefficients=coefficients
        )
        return Mask.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
    ) -> Mask:
        """Apply ``self._F.elastic_mask`` with ``displacement`` and ``fill`` and re-wrap the result as a ``Mask``.

        ``interpolation`` is part of the signature but is not forwarded to ``elastic_mask``.
        """
        output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
        return Mask.wrap_like(self, output)