_presets.py 7.83 KB
Newer Older
1
2
3
4
5
"""
This file is part of the private API. Please do not use directly these classes as they will be modified on
future versions without warning. The classes should be accessed only via the transforms argument of Weights.
"""
from typing import Optional, Tuple
6
7
8
9

import torch
from torch import Tensor, nn

10
from . import functional as F, InterpolationMode
11
12


13
# Public names exported by this private-API module.
__all__ = [
    "ObjectDetection",
    "ImageClassification",
    "VideoClassification",
    "SemanticSegmentation",
    "OpticalFlow",
]


class ObjectDetection(nn.Module):
    """Inference transform for detection models: convert the input to a float tensor in ``[0.0, 1.0]``."""

    def forward(self, img: Tensor) -> Tensor:
        # PIL inputs become uint8 tensors first; the dtype cast then rescales to [0.0, 1.0].
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        return F.convert_image_dtype(img, torch.float)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[0.0, 1.0]``."
        )


class ImageClassification(nn.Module):
    """Inference transform for classification models.

    Resizes, center-crops, rescales to ``[0.0, 1.0]`` and normalizes the input
    with the given ``mean``/``std`` statistics.
    """

    def __init__(
        self,
        *,
        crop_size: int,
        resize_size: int = 256,
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__()
        # F.resize / F.center_crop expect sequences, hence the single-element lists.
        self.crop_size = [crop_size]
        self.resize_size = [resize_size]
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation

    def forward(self, img: Tensor) -> Tensor:
        # Resize/crop handle both PIL images and tensors; the conversion to a
        # float tensor only happens afterwards.
        img = F.resize(img, self.resize_size, interpolation=self.interpolation)
        img = F.center_crop(img, self.crop_size)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        return F.normalize(img, mean=self.mean, std=self.std)

    def __repr__(self) -> str:
        body = "".join(
            f"\n    {name}={value}"
            for name, value in (
                ("crop_size", self.crop_size),
                ("resize_size", self.resize_size),
                ("mean", self.mean),
                ("std", self.std),
                ("interpolation", self.interpolation),
            )
        )
        return self.__class__.__name__ + "(" + body + "\n)"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
        )


class VideoClassification(nn.Module):
    """Inference transform for video classification models.

    Resizes and center-crops every frame, rescales to ``[0.0, 1.0]``, normalizes
    with the given statistics and permutes the output to ``(..., C, T, H, W)``.
    """

    def __init__(
        self,
        *,
        crop_size: Tuple[int, int],
        resize_size: Tuple[int, int],
        mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
        std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__()
        self.crop_size = list(crop_size)
        self.resize_size = list(resize_size)
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation

    def forward(self, vid: Tensor) -> Tensor:
        # Single clips (T, C, H, W) get a temporary batch dimension.
        added_batch_dim = vid.ndim < 5
        if added_batch_dim:
            vid = vid.unsqueeze(dim=0)

        N, T, C, H, W = vid.shape
        # Fold batch and time together so the image transforms run per frame.
        vid = vid.view(-1, C, H, W)
        vid = F.resize(vid, self.resize_size, interpolation=self.interpolation)
        vid = F.center_crop(vid, self.crop_size)
        vid = F.convert_image_dtype(vid, torch.float)
        vid = F.normalize(vid, mean=self.mean, std=self.std)

        # Restore the (N, T, ...) layout with the post-crop spatial size.
        H, W = self.crop_size
        vid = vid.view(N, T, C, H, W)
        vid = vid.permute(0, 2, 1, 3, 4)  # (N, T, C, H, W) => (N, C, T, H, W)

        if added_batch_dim:
            vid = vid.squeeze(dim=0)
        return vid

    def __repr__(self) -> str:
        body = "".join(
            f"\n    {name}={value}"
            for name, value in (
                ("crop_size", self.crop_size),
                ("resize_size", self.resize_size),
                ("mean", self.mean),
                ("std", self.std),
                ("interpolation", self.interpolation),
            )
        )
        return self.__class__.__name__ + "(" + body + "\n)"

    def describe(self) -> str:
        return (
            "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
            f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output "
            "dimensions are permuted to ``(..., C, T, H, W)`` tensors."
        )


class SemanticSegmentation(nn.Module):
    """Inference transform for semantic segmentation models.

    Optionally resizes, then rescales to ``[0.0, 1.0]`` and normalizes the
    input with the given ``mean``/``std`` statistics.
    """

    def __init__(
        self,
        *,
        resize_size: Optional[int],
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__()
        # ``None`` disables resizing; otherwise F.resize expects a sequence.
        self.resize_size = [resize_size] if resize_size is not None else None
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation

    def forward(self, img: Tensor) -> Tensor:
        if self.resize_size is not None:
            img = F.resize(img, self.resize_size, interpolation=self.interpolation)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        return F.normalize(img, mean=self.mean, std=self.std)

    def __repr__(self) -> str:
        body = "".join(
            f"\n    {name}={value}"
            for name, value in (
                ("resize_size", self.resize_size),
                ("mean", self.mean),
                ("std", self.std),
                ("interpolation", self.interpolation),
            )
        )
        return self.__class__.__name__ + "(" + body + "\n)"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
            f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
            f"``std={self.std}``."
        )


class OpticalFlow(nn.Module):
    """Inference transform for optical-flow models: float image pairs normalized to ``[-1.0, 1.0]``."""

    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
        processed = []
        for img in (img1, img2):
            if not isinstance(img, Tensor):
                img = F.pil_to_tensor(img)
            img = F.convert_image_dtype(img, torch.float)
            # map [0, 1] into [-1, 1]
            img = F.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            processed.append(img.contiguous())
        return processed[0], processed[1]

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[-1.0, 1.0]``."
        )