# deeplabv3.py
1
2
from functools import partial
from typing import Any, List, Optional
3

4
5
6
7
import torch
from torch import nn
from torch.nn import functional as F

8
from ...transforms._presets import SemanticSegmentation
9
10
11
12
13
14
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import ResNet, resnet50, resnet101, ResNet50_Weights, ResNet101_Weights
from ._utils import _SimpleSegmentationModel
15
from .fcn import FCNHead
16
17


18
19
# Names exported by ``from <module> import *``: the model class, the
# pretrained-weight enums, and the three builder functions.
__all__ = [
    "DeepLabV3",
    "DeepLabV3_ResNet50_Weights",
    "DeepLabV3_ResNet101_Weights",
    "DeepLabV3_MobileNet_V3_Large_Weights",
    "deeplabv3_mobilenet_v3_large",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
]


29
class DeepLabV3(_SimpleSegmentationModel):
    """
    Implements DeepLabV3 model from
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """

    # No members are added here; everything is inherited from
    # _SimpleSegmentationModel. The subclass only provides a distinct type/name.
    pass


class DeepLabHead(nn.Sequential):
    """Classification head of DeepLabV3.

    Stacks an :class:`ASPP` module, a 3x3 Conv-BN-ReLU block and a final 1x1
    convolution that emits ``num_classes`` channels.

    Args:
        in_channels (int): number of channels of the feature map fed to the head.
        num_classes (int): number of output channels of the final 1x1 convolution.
        atrous_rates (List[int], optional): dilation rates of the ASPP branches.
            Defaults to ``[12, 24, 36]``, the value that was previously hard-coded.
    """

    def __init__(self, in_channels: int, num_classes: int, atrous_rates: Optional[List[int]] = None) -> None:
        # ``None`` (rather than a list literal) as default avoids a shared mutable default.
        if atrous_rates is None:
            atrous_rates = [12, 24, 36]
        super().__init__(
            ASPP(in_channels, atrous_rates),
            # 256 matches the default ``out_channels`` of ASPP.
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1),
        )


class ASPPConv(nn.Sequential):
    """One dilated branch of ASPP: 3x3 ``Conv2d`` -> ``BatchNorm2d`` -> ``ReLU``.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        dilation (int): dilation of the 3x3 convolution; also used as its padding.
    """

    def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
        super().__init__(
            # padding == dilation keeps the spatial size unchanged for a 3x3 kernel with stride 1.
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
67
68
69


class ASPPPooling(nn.Sequential):
    """Image-level pooling branch of ASPP.

    Pools the input to 1x1, applies a 1x1 Conv-BN-ReLU, then resizes the result
    back to the spatial size of the input.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Remember the incoming height/width so the pooled features can be resized back.
        size = x.shape[-2:]
        # Explicit loop over the child modules instead of delegating to
        # nn.Sequential.forward.
        # NOTE(review): presumably kept in this form for TorchScript - confirm before changing.
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
83
84
85


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs the input through several parallel branches (a 1x1 convolution, one
    :class:`ASPPConv` per atrous rate, and an :class:`ASPPPooling` branch),
    concatenates the branch outputs along the channel dimension and projects
    them back to ``out_channels``.

    Args:
        in_channels (int): number of input channels.
        atrous_rates (List[int]): dilation rates, one ``ASPPConv`` branch per entry.
        out_channels (int): number of channels produced by every branch and by
            the final projection. Default: 256.
    """

    def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
        super().__init__()

        # Branch order matters: it fixes the indices under ``convs`` in the state dict.
        modules = [
            nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU())
        ]
        modules.extend(ASPPConv(in_channels, out_channels, rate) for rate in atrous_rates)
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            # The concatenation has one ``out_channels``-wide slab per branch.
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Every branch sees the same input; outputs are stacked on dim 1.
        # Explicit loop over the ModuleList is kept as in the original.
        # NOTE(review): presumably for TorchScript - confirm before changing.
        branch_outputs = []
        for conv in self.convs:
            branch_outputs.append(conv(x))
        res = torch.cat(branch_outputs, dim=1)
        return self.project(res)
114
115
116


def _deeplabv3_resnet(
    backbone: ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    """Assemble a DeepLabV3 model on top of a ResNet backbone.

    Args:
        backbone (ResNet): the ResNet whose ``layer4`` (and optionally ``layer3``)
            outputs are used as features.
        num_classes (int): number of output classes of the classifier heads.
        aux (bool, optional): if truthy, also expose ``layer3`` as ``"aux"`` and
            attach an auxiliary ``FCNHead``.

    Returns:
        DeepLabV3: the wrapped backbone with its classifier and optional auxiliary classifier.
    """
    return_layers = {"layer4": "out"}
    if aux:
        return_layers["layer3"] = "aux"
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    # 1024 / 2048 are presumably the channel counts of layer3 / layer4 for
    # bottleneck ResNets - confirm against ..resnet before reusing with other variants.
    aux_classifier = FCNHead(1024, num_classes) if aux else None
    classifier = DeepLabHead(2048, num_classes)
    return DeepLabV3(backbone, classifier, aux_classifier)


131
132
# Metadata shared by every pretrained-weights entry in this module; each
# ``Weights(meta=...)`` below spreads it and adds its own entries.
_COMMON_META = {
    "categories": _VOC_CATEGORIES,
    "min_size": (1, 1),
    "_docs": """
        These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC
        dataset.
    """,
}


class DeepLabV3_ResNet50_Weights(WeightsEnum):
    # Pretrained weights available for the DeepLabV3 / ResNet-50 builder.
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            **_COMMON_META,
            "num_params": 42004074,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
            "metrics": {
                "miou": 66.4,
                "pixel_acc": 92.4,
            },
        },
    )
    # Alias for the entry that should be used when no specific version is requested.
    DEFAULT = COCO_WITH_VOC_LABELS_V1


class DeepLabV3_ResNet101_Weights(WeightsEnum):
    # Pretrained weights available for the DeepLabV3 / ResNet-101 builder.
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            **_COMMON_META,
            "num_params": 60996202,
            # The anchor previously pointed at "#fcn_resnet101", i.e. the FCN recipe;
            # it now follows the "#deeplabv3_<backbone>" pattern of the sibling enums.
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101",
            "metrics": {
                "miou": 67.4,
                "pixel_acc": 92.4,
            },
        },
    )
    # Alias for the entry that should be used when no specific version is requested.
    DEFAULT = COCO_WITH_VOC_LABELS_V1


class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
    # Pretrained weights available for the DeepLabV3 / MobileNetV3-Large builder.
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            **_COMMON_META,
            "num_params": 11029328,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
            "metrics": {
                "miou": 60.3,
                "pixel_acc": 91.2,
            },
        },
    )
    # Alias for the entry that should be used when no specific version is requested.
    DEFAULT = COCO_WITH_VOC_LABELS_V1


192
def _deeplabv3_mobilenetv3(
    backbone: MobileNetV3,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    """Assemble a DeepLabV3 model on top of a MobileNetV3 backbone.

    Only ``backbone.features`` is used. The last feature block feeds the main
    classifier; when ``aux`` is truthy, an earlier strided block feeds an
    auxiliary ``FCNHead``.

    Args:
        backbone (MobileNetV3): the MobileNetV3 network providing ``features``.
        num_classes (int): number of output classes of the classifier heads.
        aux (bool, optional): if truthy, attach the auxiliary classifier.

    Returns:
        DeepLabV3: the wrapped feature extractor with its classifier heads.
    """
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    out_pos = stage_indices[-1]  # use C5 which has output_stride = 16
    out_inplanes = backbone[out_pos].out_channels
    # Indexing with -4 requires at least four entries in ``stage_indices``; this
    # is evaluated even when ``aux`` is falsy.
    aux_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    aux_inplanes = backbone[aux_pos].out_channels
    return_layers = {str(out_pos): "out"}
    if aux:
        return_layers[str(aux_pos)] = "aux"
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
    classifier = DeepLabHead(out_inplanes, num_classes)
    return DeepLabV3(backbone, classifier, aux_classifier)


215
216
217
218
@handle_legacy_interface(
    weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def deeplabv3_resnet50(
    *,
    weights: Optional[DeepLabV3_ResNet50_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.

    Args:
        weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for the
            backbone
        **kwargs: unused

    .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet50_Weights
        :members:
    """
    weights = DeepLabV3_ResNet50_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # The full-model checkpoint is loaded below, so separate backbone
        # weights are not requested.
        weights_backbone = None
        # NOTE(review): _ovewrite_value_param presumably fills in a missing value and
        # rejects a conflicting one - confirm in .._utils.
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        # Fallback class count when neither weights nor num_classes is given.
        num_classes = 21

    backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


268
269
270
271
@handle_legacy_interface(
    weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
)
def deeplabv3_resnet101(
    *,
    weights: Optional[DeepLabV3_ResNet101_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.

    Args:
        weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (:class:`~torchvision.models.ResNet101_Weights`, optional): The pretrained weights for the
            backbone
        **kwargs: unused

    .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet101_Weights
        :members:
    """
    weights = DeepLabV3_ResNet101_Weights.verify(weights)
    weights_backbone = ResNet101_Weights.verify(weights_backbone)

    if weights is not None:
        # The full-model checkpoint is loaded below, so separate backbone
        # weights are not requested.
        weights_backbone = None
        # NOTE(review): _ovewrite_value_param presumably fills in a missing value and
        # rejects a conflicting one - confirm in .._utils.
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        # Fallback class count when neither weights nor num_classes is given.
        num_classes = 21

    backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


321
322
323
324
@handle_legacy_interface(
    weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def deeplabv3_mobilenet_v3_large(
    *,
    weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

    Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.

    Args:
        weights (:class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained weights
            for the backbone
        **kwargs: unused

    .. autoclass:: torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights
        :members:
    """
    weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    if weights is not None:
        # The full-model checkpoint is loaded below, so separate backbone
        # weights are not requested.
        weights_backbone = None
        # NOTE(review): _ovewrite_value_param presumably fills in a missing value and
        # rejects a conflicting one - confirm in .._utils.
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        # Fallback class count when neither weights nor num_classes is given.
        num_classes = 21

    backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
    model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
372
373
374
375
376
377
378
379
380
381
382
383
384


# The dictionary below is internal implementation detail and will be removed in v0.15
# It maps legacy checkpoint names to the URLs of the corresponding weight enum
# entries defined above, so both always point at the same files.
from .._utils import _ModelURLs


model_urls = _ModelURLs(
    {
        "deeplabv3_resnet50_coco": DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1.url,
        "deeplabv3_resnet101_coco": DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1.url,
        "deeplabv3_mobilenet_v3_large_coco": DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1.url,
    }
)