# segmentation.py (9.38 KB)
1
from typing import Any, Optional
2
3
4

from torch import nn

5
from ..._internally_replaced_utils import load_state_dict_from_url
6
from .. import mobilenetv3
7
from .. import resnet
8
from .._utils import IntermediateLayerGetter
9
10
from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead
11
from .lraspp import LRASPP
12
13


14
15
16
17
18
19
20
21
# Public model constructors exported by this module; each name matches a
# builder function defined further down in the file.
__all__ = [
    "fcn_resnet50",
    "fcn_resnet101",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
    "deeplabv3_mobilenet_v3_large",
    "lraspp_mobilenet_v3_large",
]
22
23


24
# Download URLs of the pretrained checkpoints. Keys follow the
# "<arch_type>_<backbone>_coco" pattern that ``_load_weights`` assembles
# before looking a checkpoint up.
model_urls = {
    "fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
    "fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
    "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
    "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
    "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
    "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
}


34
def _segm_model(
35
    name: str, backbone_name: str, num_classes: int, aux: Optional[bool], pretrained_backbone: bool = True
36
) -> nn.Module:
37
    if "resnet" in backbone_name:
38
        backbone = resnet.__dict__[backbone_name](
39
40
41
            pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]
        )
        out_layer = "layer4"
42
        out_inplanes = 2048
43
        aux_layer = "layer3"
44
        aux_inplanes = 1024
45
    elif "mobilenet_v3" in backbone_name:
46
        backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features
47
48
49
50
51
52
53
54
55
56
57

        # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
        # The first and last blocks are always included because they are the C0 (conv1) and Cn.
        stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
        out_pos = stage_indices[-1]  # use C5 which has output_stride = 16
        out_layer = str(out_pos)
        out_inplanes = backbone[out_pos].out_channels
        aux_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
        aux_layer = str(aux_pos)
        aux_inplanes = backbone[aux_pos].out_channels
    else:
58
        raise NotImplementedError("backbone {} is not supported as of now".format(backbone_name))
59

60
    return_layers = {out_layer: "out"}
61
    if aux:
62
        return_layers[aux_layer] = "aux"
63
64
65
66
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None
    if aux:
67
        aux_classifier = FCNHead(aux_inplanes, num_classes)
68
69

    model_map = {
70
71
        "deeplabv3": (DeepLabHead, DeepLabV3),
        "fcn": (FCNHead, FCN),
72
    }
73
    classifier = model_map[name][0](out_inplanes, num_classes)
74
75
76
77
78
79
    base_model = model_map[name][1]

    model = base_model(backbone, classifier, aux_classifier)
    return model


80
81
82
83
84
85
86
def _load_model(
    arch_type: str,
    backbone: str,
    pretrained: bool,
    progress: bool,
    num_classes: int,
    aux_loss: Optional[bool],
    **kwargs: Any,
) -> nn.Module:
    """Build a segmentation model and optionally load its pretrained weights.

    Args:
        arch_type (str): architecture name passed to ``_segm_model`` (e.g. ``"fcn"``)
        backbone (str): backbone name passed to ``_segm_model`` (e.g. ``"resnet50"``)
        pretrained (bool): if True, the checkpoint for ``arch_type``/``backbone`` is loaded
            into the model; this also forces ``aux_loss`` to True and disables loading of
            separate backbone weights
        progress (bool): forwarded to the checkpoint download helper
        num_classes (int): number of output classes of the model
        aux_loss (bool, optional): whether to build the auxiliary classifier; overridden
            when ``pretrained`` is True
        **kwargs: extra keyword arguments forwarded to ``_segm_model``

    Returns:
        nn.Module: the constructed (and possibly weight-initialised) model
    """
    if pretrained:
        # NOTE(review): presumably the published checkpoints contain auxiliary-classifier
        # weights, so the model must be built with the aux head to load them -- confirm.
        aux_loss = True
        # The full checkpoint is loaded below, so separate backbone weights are not requested.
        kwargs["pretrained_backbone"] = False
    model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)
    if pretrained:
        _load_weights(model, arch_type, backbone, progress)
    return model


98
def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None:
    """Download the checkpoint for ``arch_type``/``backbone`` and load it into ``model``.

    Args:
        model (nn.Module): model whose ``load_state_dict`` receives the downloaded weights
        arch_type (str): architecture part of the checkpoint key (e.g. ``"fcn"``)
        backbone (str): backbone part of the checkpoint key (e.g. ``"resnet50"``)
        progress (bool): forwarded to ``load_state_dict_from_url``

    Raises:
        NotImplementedError: if ``model_urls`` has no entry for the requested combination
    """
    arch = arch_type + "_" + backbone + "_coco"
    model_url = model_urls.get(arch)
    if model_url is None:
        raise NotImplementedError(f"pretrained {arch} is not supported as of now")

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    model.load_state_dict(state_dict)


108
def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_backbone: bool = True) -> LRASPP:
    """Build an LRASPP model on top of the dilated feature extractor of a MobileNetV3 backbone.

    Args:
        backbone_name (str): name of the backbone constructor looked up in the ``mobilenetv3`` module
        num_classes (int): number of output classes passed to ``LRASPP``
        pretrained_backbone (bool): forwarded as ``pretrained`` to the backbone constructor

    Returns:
        LRASPP: the assembled model
    """
    backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features

    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    low_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    high_pos = stage_indices[-1]  # use C5 which has output_stride = 16
    low_channels = backbone[low_pos].out_channels
    high_channels = backbone[high_pos].out_channels

    # Expose the two selected feature maps under the names "low" and "high".
    backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"})

    return LRASPP(backbone, low_channels, high_channels, num_classes)


125
126
127
128
129
def fcn_resnet50(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> nn.Module:
    """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        **kwargs: extra keyword arguments forwarded to the model builder
            (e.g. ``pretrained_backbone``)
    """
    return _load_model("fcn", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs)
142
143


144
145
146
147
148
def fcn_resnet101(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> nn.Module:
    """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        **kwargs: extra keyword arguments forwarded to the model builder
            (e.g. ``pretrained_backbone``)
    """
    return _load_model("fcn", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs)
161
162


163
164
165
166
167
def deeplabv3_resnet50(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> nn.Module:
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        **kwargs: extra keyword arguments forwarded to the model builder
            (e.g. ``pretrained_backbone``)
    """
    return _load_model("deeplabv3", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs)
180
181


182
183
184
185
186
def deeplabv3_resnet101(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> nn.Module:
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        **kwargs: extra keyword arguments forwarded to the model builder
            (e.g. ``pretrained_backbone``)
    """
    return _load_model("deeplabv3", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs)
199
200


201
202
203
204
205
def deeplabv3_mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> nn.Module:
    """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        **kwargs: extra keyword arguments forwarded to the model builder
            (e.g. ``pretrained_backbone``)
    """
    return _load_model("deeplabv3", "mobilenet_v3_large", pretrained, progress, num_classes, aux_loss, **kwargs)
218
219


220
def lraspp_mobilenet_v3_large(
    pretrained: bool = False, progress: bool = True, num_classes: int = 21, **kwargs: Any
) -> nn.Module:
    """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        **kwargs: extra keyword arguments forwarded to the model builder
            (e.g. ``pretrained_backbone``); a truthy ``aux_loss`` is rejected

    Raises:
        NotImplementedError: if a truthy ``aux_loss`` keyword argument is supplied
    """
    # ``aux_loss`` is accepted only so it can be rejected explicitly; it is removed from
    # kwargs so a falsy value is not forwarded to the builder.
    if kwargs.pop("aux_loss", False):
        raise NotImplementedError("This model does not use auxiliary loss")

    backbone_name = "mobilenet_v3_large"
    if pretrained:
        # The full checkpoint is loaded below, so separate backbone weights are not requested.
        kwargs["pretrained_backbone"] = False
    model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs)

    if pretrained:
        _load_weights(model, "lraspp", backbone_name, progress)

    return model