from functools import partial
from typing import Any, List, Optional

import torch
from torch import nn
from torch.nn import functional as F

from ...transforms._presets import SemanticSegmentation, InterpolationMode
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import ResNet, resnet50, resnet101, ResNet50_Weights, ResNet101_Weights
from ._utils import _SimpleSegmentationModel
from .fcn import FCNHead


__all__ = [
    "DeepLabV3",
20
21
22
23
    "DeepLabV3_ResNet50_Weights",
    "DeepLabV3_ResNet101_Weights",
    "DeepLabV3_MobileNet_V3_Large_Weights",
    "deeplabv3_mobilenet_v3_large",
24
25
26
27
28
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
]


class DeepLabV3(_SimpleSegmentationModel):
    """DeepLabV3 semantic segmentation model.

    Implements the architecture from `"Rethinking Atrous Convolution for
    Semantic Image Segmentation" <https://arxiv.org/abs/1706.05587>`_.
    All forward logic lives in :class:`_SimpleSegmentationModel`; this class
    only fixes the documented contract of its constructor arguments.

    Args:
        backbone (nn.Module): feature extractor. It must return an
            OrderedDict[Tensor] with the key "out" for the last feature map,
            plus "aux" when an auxiliary classifier is attached.
        classifier (nn.Module): maps the "out" features to a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier applied to
            the "aux" features during training.
    """

    pass


class DeepLabHead(nn.Sequential):
    """DeepLabV3 segmentation head.

    Pipeline: ASPP over the backbone features, a 3x3 conv block (256
    channels, BN + ReLU), then a 1x1 conv producing per-class logits.
    """

    def __init__(self, in_channels: int, num_classes: int) -> None:
        layers = [
            ASPP(in_channels, [12, 24, 36]),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1),
        ]
        super().__init__(*layers)


class ASPPConv(nn.Sequential):
    """One dilated-convolution branch of ASPP.

    A 3x3 convolution with the given dilation (padding chosen to preserve
    spatial size), followed by batch norm and ReLU.
    """

    def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
        super().__init__(
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class ASPPPooling(nn.Sequential):
70
    def __init__(self, in_channels: int, out_channels: int) -> None:
71
        super().__init__(
72
73
74
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
75
76
            nn.ReLU(),
        )
77

78
    def forward(self, x: torch.Tensor) -> torch.Tensor:
79
        size = x.shape[-2:]
eellison's avatar
eellison committed
80
81
        for mod in self:
            x = mod(x)
82
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
83
84
85


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs the input through parallel branches -- a 1x1 conv, one 3x3 dilated
    conv per atrous rate, and an image-level pooling branch -- concatenates
    the results along the channel axis, and fuses them with a 1x1 projection
    (BN + ReLU + Dropout).
    """

    def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
        super().__init__()
        # Branch 0: plain 1x1 convolution.
        branches: List[nn.Module] = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        ]
        # One dilated-conv branch per requested rate.
        branches.extend(ASPPConv(in_channels, out_channels, rate) for rate in tuple(atrous_rates))
        # Final branch: global image-level context.
        branches.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(branches)

        # Fuse the concatenated branch outputs back down to `out_channels`.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch_outputs = [conv(x) for conv in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))


def _deeplabv3_resnet(
    backbone: ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    """Assemble a DeepLabV3 model on top of a ResNet backbone.

    The backbone's ``layer4`` output feeds the main DeepLab head; when
    ``aux`` is truthy, ``layer3`` additionally feeds an FCN auxiliary head.
    """
    wanted = {"layer4": "out"}
    if aux:
        wanted["layer3"] = "aux"
    feature_extractor = IntermediateLayerGetter(backbone, return_layers=wanted)

    # Channel counts are hard-coded for ResNet-50/101: layer3 -> 1024, layer4 -> 2048.
    aux_head = FCNHead(1024, num_classes) if aux else None
    main_head = DeepLabHead(2048, num_classes)
    return DeepLabV3(feature_extractor, main_head, aux_head)


# Metadata shared by every DeepLabV3 weight entry below. Per-weight `meta`
# dicts spread this first and then add model-specific fields (num_params,
# recipe, metrics).
_COMMON_META = {
    "task": "image_semantic_segmentation",
    "architecture": "DeepLabV3",
    "publication_year": 2017,
    "categories": _VOC_CATEGORIES,
    "interpolation": InterpolationMode.BILINEAR,
}


class DeepLabV3_ResNet50_Weights(WeightsEnum):
    # Weights trained on COCO using the VOC label set (URL says "coco";
    # `categories` comes from _VOC_CATEGORIES via _COMMON_META).
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
        # Inference preset: SemanticSegmentation transform with resize_size=520.
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            **_COMMON_META,
            "num_params": 42004074,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
            "mIoU": 66.4,
            "acc": 92.4,
        },
    )
    DEFAULT = COCO_WITH_VOC_LABELS_V1


class DeepLabV3_ResNet101_Weights(WeightsEnum):
    # Weights trained on COCO using the VOC label set (URL says "coco";
    # `categories` comes from _VOC_CATEGORIES via _COMMON_META).
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
        # Inference preset: SemanticSegmentation transform with resize_size=520.
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            **_COMMON_META,
            "num_params": 60996202,
            # Fixed: the anchor previously pointed at `#fcn_resnet101` (a
            # copy-paste slip from the FCN weights); this enum describes the
            # DeepLabV3 ResNet-101 training recipe.
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101",
            "mIoU": 67.4,
            "acc": 92.4,
        },
    )
    DEFAULT = COCO_WITH_VOC_LABELS_V1


class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
    # Weights trained on COCO using the VOC label set (URL says "coco";
    # `categories` comes from _VOC_CATEGORIES via _COMMON_META).
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
        # Inference preset: SemanticSegmentation transform with resize_size=520.
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            **_COMMON_META,
            "num_params": 11029328,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
            "mIoU": 60.3,
            "acc": 91.2,
        },
    )
    DEFAULT = COCO_WITH_VOC_LABELS_V1


def _deeplabv3_mobilenetv3(
    backbone: MobileNetV3,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    """Assemble a DeepLabV3 model on top of a MobileNetV3 backbone.

    The C5 stage (output_stride = 16) feeds the main head; when ``aux`` is
    truthy, C2 (output_stride = 8) additionally feeds an FCN auxiliary head.
    """
    features = backbone.features
    # Stage boundaries: the first block (C0 / conv1), every strided block
    # (marked with the `_is_cn` attribute), and the final block (Cn).
    stage_indices = [0]
    stage_indices += [idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)]
    stage_indices.append(len(features) - 1)

    out_pos = stage_indices[-1]  # C5, output_stride = 16
    aux_pos = stage_indices[-4]  # C2, output_stride = 8
    out_inplanes = features[out_pos].out_channels
    aux_inplanes = features[aux_pos].out_channels

    return_layers = {str(out_pos): "out"}
    if aux:
        return_layers[str(aux_pos)] = "aux"
    feature_extractor = IntermediateLayerGetter(features, return_layers=return_layers)

    aux_head = FCNHead(aux_inplanes, num_classes) if aux else None
    main_head = DeepLabHead(out_inplanes, num_classes)
    return DeepLabV3(feature_extractor, main_head, aux_head)


@handle_legacy_interface(
    weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def deeplabv3_resnet50(
    *,
    weights: Optional[DeepLabV3_ResNet50_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        weights (DeepLabV3_ResNet50_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
    """
    weights = DeepLabV3_ResNet50_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # Full-model weights fix num_classes and the aux head, and make the
        # backbone-only weights redundant (they would be overwritten anyway).
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        num_classes = 21  # default to the VOC label set: 20 classes + background

    # replace_stride_with_dilation keeps the last two ResNet stages at a
    # higher spatial resolution (dilation instead of striding).
    backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


@handle_legacy_interface(
    weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
)
def deeplabv3_resnet101(
    *,
    weights: Optional[DeepLabV3_ResNet101_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        weights (DeepLabV3_ResNet101_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (ResNet101_Weights, optional): The pretrained weights for the backbone
    """
    weights = DeepLabV3_ResNet101_Weights.verify(weights)
    weights_backbone = ResNet101_Weights.verify(weights_backbone)

    if weights is not None:
        # Full-model weights fix num_classes and the aux head, and make the
        # backbone-only weights redundant (they would be overwritten anyway).
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        num_classes = 21  # default to the VOC label set: 20 classes + background

    # replace_stride_with_dilation keeps the last two ResNet stages at a
    # higher spatial resolution (dilation instead of striding).
    backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


@handle_legacy_interface(
    weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def deeplabv3_mobilenet_v3_large(
    *,
    weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

    Args:
        weights (DeepLabV3_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone
    """
    weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    if weights is not None:
        # Full-model weights fix num_classes and the aux head, and make the
        # backbone-only weights redundant (they would be overwritten anyway).
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        num_classes = 21  # default to the VOC label set: 20 classes + background

    # dilated=True builds the backbone with dilation in the last stage so
    # features stay at a resolution suitable for dense prediction.
    backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
    model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model