from __future__ import annotations

from typing import Any, List, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision.transforms import InterpolationMode

from ._datapoint import _FillTypeJIT, Datapoint


class Mask(Datapoint):
    """:class:`torch.Tensor` subclass representing a mask.

    A ``Mask`` is built in ``__new__``: a :class:`PIL.Image.Image` is first
    converted with ``pil_to_tensor``; the result (or any other ``data``) is then
    passed to ``_to_tensor`` together with ``dtype``, ``device`` and
    ``requires_grad``, and finally viewed as a ``Mask``.

    Every geometric method below strips the subclass
    (``self.as_subclass(torch.Tensor)``), calls the matching ``*_mask`` kernel
    on ``self._F`` and re-wraps the result with :meth:`wrap_like`. Arguments
    named ``interpolation`` or ``antialias`` are accepted but never forwarded
    to the kernels. Presumably they exist only to keep the signatures
    compatible with the ``Datapoint`` base class; confirm against the base.

    Args:
        data (Any): A :class:`PIL.Image.Image`, or anything ``_to_tensor``
            accepts (presumably tensor-like data; see ``Datapoint``).
        dtype (torch.dtype, optional): Forwarded to ``_to_tensor``.
        device (torch.device, str or int, optional): Forwarded to ``_to_tensor``.
        requires_grad (bool, optional): Forwarded to ``_to_tensor``.
    """

    @classmethod
    def _wrap(cls, tensor: torch.Tensor) -> Mask:
        """Reinterpret ``tensor`` as an instance of ``cls`` via ``Tensor.as_subclass``."""
        return tensor.as_subclass(cls)

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Mask:
        if isinstance(data, PIL.Image.Image):
            # Imported lazily rather than at module level; presumably to avoid a
            # circular import with the v2 functional module (confirm).
            from torchvision.transforms.v2 import functional as F

            data = F.pil_to_tensor(data)

        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return cls._wrap(tensor)

    @classmethod
    def wrap_like(
        cls,
        other: Mask,
        tensor: torch.Tensor,
    ) -> Mask:
        """Wrap ``tensor`` as a ``Mask``.

        ``other`` is not read here; nothing is copied from it onto the result.
        """
        return cls._wrap(tensor)

    @property
    def spatial_size(self) -> Tuple[int, int]:
        """The last two entries of ``self.shape`` as a tuple (presumably height, width)."""
        return tuple(self.shape[-2:])  # type: ignore[return-value]

    def horizontal_flip(self) -> Mask:
        """Flip via ``horizontal_flip_mask`` and re-wrap as ``Mask``."""
        output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor))
        return Mask.wrap_like(self, output)

    def vertical_flip(self) -> Mask:
        """Flip via ``vertical_flip_mask`` and re-wrap as ``Mask``."""
        output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor))
        return Mask.wrap_like(self, output)

    def resize(  # type: ignore[override]
        self,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Mask:
        """Resize via ``resize_mask``.

        Only ``size`` and ``max_size`` are forwarded; ``interpolation`` and
        ``antialias`` are accepted but ignored.
        """
        output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size)
        return Mask.wrap_like(self, output)

    def crop(self, top: int, left: int, height: int, width: int) -> Mask:
        """Crop via ``crop_mask`` using the given box and re-wrap as ``Mask``."""
        output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width)
        return Mask.wrap_like(self, output)

    def center_crop(self, output_size: List[int]) -> Mask:
        """Center-crop via ``center_crop_mask`` and re-wrap as ``Mask``."""
        output = self._F.center_crop_mask(self.as_subclass(torch.Tensor), output_size=output_size)
        return Mask.wrap_like(self, output)

    def resized_crop(
        self,
        top: int,
        left: int,
        height: int,
        width: int,
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> Mask:
        """Crop then resize via ``resized_crop_mask``.

        The box and ``size`` are forwarded; ``interpolation`` and ``antialias``
        are accepted but ignored.
        """
        output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
        return Mask.wrap_like(self, output)

    def pad(
        self,
        padding: List[int],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Mask:
        """Pad via ``pad_mask``, forwarding ``padding``, ``padding_mode`` and ``fill``."""
        output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
        return Mask.wrap_like(self, output)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Mask:
        """Rotate via ``rotate_mask``.

        ``angle``, ``expand``, ``center`` and ``fill`` are forwarded;
        ``interpolation`` is accepted but ignored.
        """
        output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
        return Mask.wrap_like(self, output)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Mask:
        """Apply an affine transform via ``affine_mask``.

        All geometric arguments plus ``fill`` and ``center`` are forwarded;
        ``interpolation`` is accepted but ignored.
        """
        output = self._F.affine_mask(
            self.as_subclass(torch.Tensor),
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            fill=fill,
            center=center,
        )
        return Mask.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Mask:
        """Apply a perspective transform via ``perspective_mask``.

        ``startpoints``, ``endpoints``, ``fill`` and ``coefficients`` are
        forwarded; ``interpolation`` is accepted but ignored.
        """
        output = self._F.perspective_mask(
            self.as_subclass(torch.Tensor), startpoints, endpoints, fill=fill, coefficients=coefficients
        )
        return Mask.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
    ) -> Mask:
        """Apply an elastic transform via ``elastic_mask``.

        ``displacement`` and ``fill`` are forwarded; ``interpolation`` is
        accepted but ignored.
        """
        output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
        return Mask.wrap_like(self, output)