from functools import partial
from typing import Any, List, Optional

import torch
from torch import nn
from torch.nn import functional as F

from ...transforms._presets import SemanticSegmentation
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import ResNet, resnet50, resnet101, ResNet50_Weights, ResNet101_Weights
from ._utils import _SimpleSegmentationModel
from .fcn import FCNHead
17


18
19
__all__ = [
    "DeepLabV3",
20
21
22
23
    "DeepLabV3_ResNet50_Weights",
    "DeepLabV3_ResNet101_Weights",
    "DeepLabV3_MobileNet_V3_Large_Weights",
    "deeplabv3_mobilenet_v3_large",
24
25
26
27
28
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
]


29
class DeepLabV3(_SimpleSegmentationModel):
30
31
32
33
34
    """
    Implements DeepLabV3 model from
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

35
    Args:
36
37
38
39
40
41
42
43
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """
44

45
46
47
48
    pass


class DeepLabHead(nn.Sequential):
49
    def __init__(self, in_channels: int, num_classes: int) -> None:
50
        super().__init__(
51
52
53
54
            ASPP(in_channels, [12, 24, 36]),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
55
            nn.Conv2d(256, num_classes, 1),
56
57
58
59
        )


class ASPPConv(nn.Sequential):
60
    def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
61
62
63
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
64
            nn.ReLU(),
65
        ]
66
        super().__init__(*modules)
67
68
69


class ASPPPooling(nn.Sequential):
70
    def __init__(self, in_channels: int, out_channels: int) -> None:
71
        super().__init__(
72
73
74
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
75
76
            nn.ReLU(),
        )
77

78
    def forward(self, x: torch.Tensor) -> torch.Tensor:
79
        size = x.shape[-2:]
eellison's avatar
eellison committed
80
81
        for mod in self:
            x = mod(x)
82
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
83
84
85


class ASPP(nn.Module):
86
    def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
87
        super().__init__()
88
        modules = []
89
90
91
        modules.append(
            nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU())
        )
92

93
94
95
96
        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))

97
98
99
100
101
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
Adeel Hassan's avatar
Adeel Hassan committed
102
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
103
104
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
105
106
            nn.Dropout(0.5),
        )
107

108
109
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _res = []
110
        for conv in self.convs:
111
112
            _res.append(conv(x))
        res = torch.cat(_res, dim=1)
113
        return self.project(res)
114
115
116


def _deeplabv3_resnet(
117
    backbone: ResNet,
118
119
120
121
122
123
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    return_layers = {"layer4": "out"}
    if aux:
        return_layers["layer3"] = "aux"
124
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
125
126
127
128
129
130

    aux_classifier = FCNHead(1024, num_classes) if aux else None
    classifier = DeepLabHead(2048, num_classes)
    return DeepLabV3(backbone, classifier, aux_classifier)


131
132
_COMMON_META = {
    "categories": _VOC_CATEGORIES,
133
    "min_size": (1, 1),
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
}


class DeepLabV3_ResNet50_Weights(WeightsEnum):
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            **_COMMON_META,
            "num_params": 42004074,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
            "mIoU": 66.4,
            "acc": 92.4,
        },
    )
    DEFAULT = COCO_WITH_VOC_LABELS_V1


class DeepLabV3_ResNet101_Weights(WeightsEnum):
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            **_COMMON_META,
            "num_params": 60996202,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
            "mIoU": 67.4,
            "acc": 92.4,
        },
    )
    DEFAULT = COCO_WITH_VOC_LABELS_V1


class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            **_COMMON_META,
            "num_params": 11029328,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
            "mIoU": 60.3,
            "acc": 91.2,
        },
    )
    DEFAULT = COCO_WITH_VOC_LABELS_V1


182
def _deeplabv3_mobilenetv3(
183
    backbone: MobileNetV3,
184
185
186
187
188
189
190
191
192
193
194
195
196
197
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    out_pos = stage_indices[-1]  # use C5 which has output_stride = 16
    out_inplanes = backbone[out_pos].out_channels
    aux_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    aux_inplanes = backbone[aux_pos].out_channels
    return_layers = {str(out_pos): "out"}
    if aux:
        return_layers[str(aux_pos)] = "aux"
198
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
199
200
201
202
203
204

    aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
    classifier = DeepLabHead(out_inplanes, num_classes)
    return DeepLabV3(backbone, classifier, aux_classifier)


205
206
207
208
@handle_legacy_interface(
    weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
209
def deeplabv3_resnet50(
210
211
    *,
    weights: Optional[DeepLabV3_ResNet50_Weights] = None,
212
    progress: bool = True,
213
    num_classes: Optional[int] = None,
214
    aux_loss: Optional[bool] = None,
215
216
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    **kwargs: Any,
217
218
219
220
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    Args:
221
        weights (DeepLabV3_ResNet50_Weights, optional): The pretrained weights for the model
222
        progress (bool): If True, displays a progress bar of the download to stderr
223
        num_classes (int, optional): number of output classes of the model (including the background)
224
        aux_loss (bool, optional): If True, it uses an auxiliary loss
225
        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
226
    """
227
228
    weights = DeepLabV3_ResNet50_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)
229

230
231
232
233
234
235
236
237
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        num_classes = 21

    backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
238
239
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)

240
241
242
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

243
244
245
    return model


246
247
248
249
@handle_legacy_interface(
    weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
)
250
def deeplabv3_resnet101(
251
252
    *,
    weights: Optional[DeepLabV3_ResNet101_Weights] = None,
253
    progress: bool = True,
254
    num_classes: Optional[int] = None,
255
    aux_loss: Optional[bool] = None,
256
257
    weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1,
    **kwargs: Any,
258
259
260
261
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Args:
262
        weights (DeepLabV3_ResNet101_Weights, optional): The pretrained weights for the model
263
264
265
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): The number of classes
        aux_loss (bool, optional): If True, include an auxiliary classifier
266
        weights_backbone (ResNet101_Weights, optional): The pretrained weights for the backbone
267
    """
268
269
    weights = DeepLabV3_ResNet101_Weights.verify(weights)
    weights_backbone = ResNet101_Weights.verify(weights_backbone)
270

271
272
273
274
275
276
277
278
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        num_classes = 21

    backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
279
280
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)

281
282
283
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

284
285
286
    return model


287
288
289
290
@handle_legacy_interface(
    weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
291
def deeplabv3_mobilenet_v3_large(
292
293
    *,
    weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
294
    progress: bool = True,
295
    num_classes: Optional[int] = None,
296
    aux_loss: Optional[bool] = None,
297
298
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    **kwargs: Any,
299
300
301
302
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

    Args:
303
        weights (DeepLabV3_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model
304
        progress (bool): If True, displays a progress bar of the download to stderr
305
        num_classes (int, optional): number of output classes of the model (including the background)
306
        aux_loss (bool, optional): If True, it uses an auxiliary loss
307
        weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone
308
    """
309
310
    weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
311

312
313
314
315
316
317
318
319
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        num_classes = 21

    backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
320
321
    model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)

322
323
324
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

325
    return model