"""
This file is part of the private API. Please do not use directly these classes as they will be modified on
future versions without warning. The classes should be accessed only via the transforms argument of Weights.
"""
from typing import Optional, Tuple, Union

import torch
from torch import nn, Tensor

from . import functional as F, InterpolationMode


__all__ = [
    "ObjectDetection",
    "ImageClassification",
    "VideoClassification",
    "SemanticSegmentation",
    "OpticalFlow",
]


class ObjectDetection(nn.Module):
    def forward(self, img: Tensor) -> Tensor:
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        return F.convert_image_dtype(img, torch.float)

    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[0.0, 1.0]``."
        )
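
# A minimal usage sketch for the preset above (assumes direct construction for
# illustration; in practice the instance comes from a weights enum):
#
#     >>> preset = ObjectDetection()
#     >>> img = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
#     >>> out = preset(img)
#     >>> out.dtype, float(out.max()) <= 1.0
#     (torch.float32, True)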


class ImageClassification(nn.Module):
    def __init__(
        self,
        *,
        crop_size: int,
        resize_size: int = 256,
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> None:
        super().__init__()
        self.crop_size = [crop_size]
        self.resize_size = [resize_size]
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img: Tensor) -> Tensor:
        img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
        img = F.center_crop(img, self.crop_size)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        img = F.normalize(img, mean=self.mean, std=self.std)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n    crop_size={self.crop_size}"
        format_string += f"\n    resize_size={self.resize_size}"
        format_string += f"\n    mean={self.mean}"
        format_string += f"\n    std={self.std}"
        format_string += f"\n    interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
        )
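
# A minimal usage sketch (assumes direct construction with the ImageNet
# defaults above; shorter side resized to 256, then a 224x224 center crop):
#
#     >>> preset = ImageClassification(crop_size=224)
#     >>> img = torch.randint(0, 256, (3, 500, 400), dtype=torch.uint8)
#     >>> preset(img).shape
#     torch.Size([3, 224, 224])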


class VideoClassification(nn.Module):
    def __init__(
        self,
        *,
        crop_size: Tuple[int, int],
        resize_size: Tuple[int, int],
        mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
        std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__()
        self.crop_size = list(crop_size)
        self.resize_size = list(resize_size)
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation

    def forward(self, vid: Tensor) -> Tensor:
        need_squeeze = False
        if vid.ndim < 5:
            vid = vid.unsqueeze(dim=0)
            need_squeeze = True

        N, T, C, H, W = vid.shape
        vid = vid.view(-1, C, H, W)
        # We hard-code antialias=False to preserve results after we changed
        # its default from None to True (see
        # https://github.com/pytorch/vision/pull/7160)
        # TODO: we could re-train the video models with antialias=True?
        vid = F.resize(vid, self.resize_size, interpolation=self.interpolation, antialias=False)
        vid = F.center_crop(vid, self.crop_size)
        vid = F.convert_image_dtype(vid, torch.float)
        vid = F.normalize(vid, mean=self.mean, std=self.std)
        H, W = self.crop_size
        vid = vid.view(N, T, C, H, W)
        vid = vid.permute(0, 2, 1, 3, 4)  # (N, T, C, H, W) => (N, C, T, H, W)

        if need_squeeze:
            vid = vid.squeeze(dim=0)
        return vid

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n    crop_size={self.crop_size}"
        format_string += f"\n    resize_size={self.resize_size}"
        format_string += f"\n    mean={self.mean}"
        format_string += f"\n    std={self.std}"
        format_string += f"\n    interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        return (
            "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
            f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. The values are then rescaled to "
            f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally, the output "
            "dimensions are permuted to ``(..., C, T, H, W)``."
        )
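
# A minimal usage sketch (a single clip ``(T, C, H, W)`` is batched, resized,
# cropped, normalized, and permuted so that time follows channels; the sizes
# below are illustrative Kinetics-style values):
#
#     >>> preset = VideoClassification(crop_size=(112, 112), resize_size=(128, 171))
#     >>> clip = torch.rand(16, 3, 240, 320)  # (T, C, H, W)
#     >>> preset(clip).shape
#     torch.Size([3, 16, 112, 112])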


class SemanticSegmentation(nn.Module):
    def __init__(
        self,
        *,
        resize_size: Optional[int],
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> None:
        super().__init__()
        self.resize_size = [resize_size] if resize_size is not None else None
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img: Tensor) -> Tensor:
        if isinstance(self.resize_size, list):
            img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        img = F.normalize(img, mean=self.mean, std=self.std)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n    resize_size={self.resize_size}"
        format_string += f"\n    mean={self.mean}"
        format_string += f"\n    std={self.std}"
        format_string += f"\n    interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
            f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
            f"``std={self.std}``."
        )
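
# A minimal usage sketch (``resize_size=None`` skips the optional resize in the
# forward pass above, so the spatial size is preserved):
#
#     >>> preset = SemanticSegmentation(resize_size=None)
#     >>> img = torch.randint(0, 256, (3, 360, 480), dtype=torch.uint8)
#     >>> preset(img).shape
#     torch.Size([3, 360, 480])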


class OpticalFlow(nn.Module):
    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
        if not isinstance(img1, Tensor):
            img1 = F.pil_to_tensor(img1)
        if not isinstance(img2, Tensor):
            img2 = F.pil_to_tensor(img2)

        img1 = F.convert_image_dtype(img1, torch.float)
        img2 = F.convert_image_dtype(img2, torch.float)

        # map [0, 1] into [-1, 1]
        img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        img2 = F.normalize(img2, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        img1 = img1.contiguous()
        img2 = img2.contiguous()

        return img1, img2

    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[-1.0, 1.0]``."
        )
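
# A minimal usage sketch (two frames in, two contiguous frames out, both mapped
# from ``[0, 1]`` to ``[-1, 1]`` by the normalization above):
#
#     >>> preset = OpticalFlow()
#     >>> frame1, frame2 = torch.rand(3, 256, 256), torch.rand(3, 256, 256)
#     >>> out1, out2 = preset(frame1, frame2)
#     >>> bool(out1.abs().max() <= 1.0)
#     True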