_mask.py 5.54 KB
Newer Older
1
2
from __future__ import annotations

Philip Meier's avatar
Philip Meier committed
3
from typing import Any, List, Optional, Union
4

5
import PIL.Image
6
import torch
7
8
from torchvision.transforms import InterpolationMode

Philip Meier's avatar
Philip Meier committed
9
from ._datapoint import _FillTypeJIT, Datapoint
10
11


12
class Mask(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for segmentation and detection masks.

    Args:
        data (tensor-like, PIL.Image.Image): Any data that can be turned into a tensor with :func:`torch.as_tensor` as
            well as PIL images.
        dtype (torch.dtype, optional): Desired data type of the mask. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device of the mask. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the mask is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations on the mask. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    # Every transform method below follows the same three steps:
    #   1. view ``self`` as a plain ``torch.Tensor`` with ``as_subclass``,
    #   2. call the matching ``self._F.<op>_mask`` function (``_F`` is not defined in this class; it is
    #      presumably provided by ``Datapoint`` -- confirm there),
    #   3. re-wrap the result with ``Mask.wrap_like``.
    # ``Mask`` is hard-coded in step 3, so the result is a ``Mask`` even when called on a subclass instance.
    #
    # Several methods accept ``interpolation`` (and ``antialias``) but never forward them to the ``*_mask``
    # function. NOTE(review): presumably they exist only to keep the signatures aligned with the other
    # datapoint types -- confirm against ``Datapoint``.

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Mask:
        """Return ``tensor`` viewed as an instance of ``cls`` via :meth:`torch.Tensor.as_subclass`."""
        return tensor.as_subclass(cls)

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Mask:
        """Build a mask from ``data``; see the class docstring for the meaning of the arguments.

        PIL images are first converted with ``pil_to_tensor``; everything is then passed through
        ``cls._to_tensor`` and wrapped as ``cls``.
        """
        if isinstance(data, PIL.Image.Image):
            # Imported at call time rather than at module level.
            # NOTE(review): presumably to avoid a circular import with the transforms package -- confirm.
            from torchvision.transforms.v2 import functional as F

            data = F.pil_to_tensor(data)

        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return cls._wrap(tensor)

    @classmethod
    def wrap_like(
        cls,
        other: Mask,
        tensor: torch.Tensor,
    ) -> Mask:
        """Wrap ``tensor`` as ``cls``.

        ``other`` is accepted but not used here: nothing is copied over from it.
        """
        return cls._wrap(tensor)

    def horizontal_flip(self) -> Mask:
        """Apply ``self._F.horizontal_flip_mask`` and return the result as a :class:`Mask`."""
        output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor))
        return Mask.wrap_like(self, output)

    def vertical_flip(self) -> Mask:
        """Apply ``self._F.vertical_flip_mask`` and return the result as a :class:`Mask`."""
        output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor))
        return Mask.wrap_like(self, output)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Mask:
        """Apply ``self._F.resize_mask`` with ``size`` and ``max_size`` and return a :class:`Mask`.

        ``interpolation`` and ``antialias`` are accepted but not forwarded.
        """
        output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size)
        return Mask.wrap_like(self, output)

    def crop(self, top: int, left: int, height: int, width: int) -> Mask:
        """Apply ``self._F.crop_mask`` with the given box and return the result as a :class:`Mask`."""
        output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width)
        return Mask.wrap_like(self, output)

    def center_crop(self, output_size: List[int]) -> Mask:
        """Apply ``self._F.center_crop_mask`` with ``output_size`` and return a :class:`Mask`."""
        output = self._F.center_crop_mask(self.as_subclass(torch.Tensor), output_size=output_size)
        return Mask.wrap_like(self, output)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Mask:
        """Apply ``self._F.resized_crop_mask`` with the given box and ``size`` and return a :class:`Mask`.

        ``interpolation`` and ``antialias`` are accepted but not forwarded.
        """
        output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
        return Mask.wrap_like(self, output)

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Mask:
        """Apply ``self._F.pad_mask`` with ``padding``, ``padding_mode`` and ``fill`` and return a :class:`Mask`."""
        output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
        return Mask.wrap_like(self, output)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Mask:
        """Apply ``self._F.rotate_mask`` with ``angle``, ``expand``, ``center`` and ``fill``; return a :class:`Mask`.

        ``interpolation`` is accepted but not forwarded.
        """
        output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
        return Mask.wrap_like(self, output)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Mask:
        """Apply ``self._F.affine_mask`` with the given affine parameters and return a :class:`Mask`.

        ``interpolation`` is accepted but not forwarded.
        """
        output = self._F.affine_mask(
            self.as_subclass(torch.Tensor),
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            fill=fill,
            center=center,
        )
        return Mask.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Mask:
        """Apply ``self._F.perspective_mask`` and return the result as a :class:`Mask`.

        ``startpoints``, ``endpoints``, ``fill`` and ``coefficients`` are forwarded; ``interpolation`` is
        accepted but not forwarded.
        """
        output = self._F.perspective_mask(
            self.as_subclass(torch.Tensor), startpoints, endpoints, fill=fill, coefficients=coefficients
        )
        return Mask.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
    ) -> Mask:
        """Apply ``self._F.elastic_mask`` with ``displacement`` and ``fill`` and return a :class:`Mask`.

        ``interpolation`` is accepted but not forwarded.
        """
        output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
        return Mask.wrap_like(self, output)