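# NOTE(review): the three lines below are residue from a repository web view, not source code.
# "vscode:/vscode.git/clone" did not exist on "e91495504a81af712a17020018b1dba4a6cf781f"
# fcn.py 8.43 KB
# Newer Older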
1
2
from functools import partial
from typing import Any, Optional
3

4
5
from torch import nn

6
from ...transforms._presets import SemanticSegmentation
7
8
9
10
11
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet, ResNet50_Weights, ResNet101_Weights, resnet50, resnet101
from ._utils import _SimpleSegmentationModel
12
13


__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"]


class FCN(_SimpleSegmentationModel):
    """
    Implements FCN model from
    `"Fully Convolutional Networks for Semantic Segmentation"
    <https://arxiv.org/abs/1411.4038>`_.

    All behavior is inherited from ``_SimpleSegmentationModel``; this subclass only gives the
    model its public name.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training.
    """


class FCNHead(nn.Sequential):
    """Fully-convolutional prediction head used by the FCN models in this file.

    The head is a ``3x3 conv -> BatchNorm -> ReLU -> Dropout(0.1) -> 1x1 conv`` stack.
    The intermediate width is ``in_channels // 4``. The final 1x1 convolution maps it to
    ``channels`` output maps at the same spatial resolution as the input.

    Args:
        in_channels (int): number of channels of the incoming feature map.
        channels (int): number of output channels (the number of classes when the head is
            used as a classifier).
    """

    def __init__(self, in_channels: int, channels: int) -> None:
        inter_channels = in_channels // 4
        layers = [
            # padding=1 keeps the spatial size unchanged for the 3x3 kernel. The conv has no
            # bias because the following BatchNorm provides its own learnable shift.
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, channels, 1),
        ]
        super().__init__(*layers)


# Metadata shared by every pretrained FCN weight entry defined in this file.
# NOTE(review): "min_size" is presumably the smallest supported (height, width) input; confirm.
_COMMON_META = {
    "categories": _VOC_CATEGORIES,
    "min_size": (1, 1),
}


class FCN_ResNet50_Weights(WeightsEnum):
    """Pretrained weights for :func:`fcn_resnet50`.

    ``DEFAULT`` is an alias of ``COCO_WITH_VOC_LABELS_V1``.
    """

    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
        # Inference preprocessing paired with this checkpoint.
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            **_COMMON_META,
            "num_params": 35322218,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50",
            # Reported accuracy of this checkpoint (mean IoU and pixel accuracy, as percentages).
            "metrics": {
                "miou": 60.5,
                "pixel_acc": 91.4,
            },
        },
    )
    DEFAULT = COCO_WITH_VOC_LABELS_V1


class FCN_ResNet101_Weights(WeightsEnum):
    """Pretrained weights for :func:`fcn_resnet101`.

    ``DEFAULT`` is an alias of ``COCO_WITH_VOC_LABELS_V1``.
    """

    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
        # Inference preprocessing paired with this checkpoint.
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            **_COMMON_META,
            "num_params": 54314346,
            # Points at the FCN-ResNet101 recipe, consistent with the ResNet-50 entry's anchor.
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
            # Reported accuracy of this checkpoint (mean IoU and pixel accuracy, as percentages).
            "metrics": {
                "miou": 63.7,
                "pixel_acc": 91.9,
            },
        },
    )
    DEFAULT = COCO_WITH_VOC_LABELS_V1


def _fcn_resnet(
    backbone: ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> FCN:
    """Wrap a ResNet trunk into an :class:`FCN` segmentation model.

    Args:
        backbone (ResNet): ResNet whose ``layer4`` output (and ``layer3`` output when ``aux``
            is truthy) is exposed as the ``"out"`` (and ``"aux"``) feature.
        num_classes (int): number of output channels of the classifier head(s).
        aux (bool, optional): if truthy, an auxiliary :class:`FCNHead` is attached to the
            ``layer3`` features.

    Returns:
        FCN: the assembled model.
    """
    # Map ResNet stage names to the feature keys that FCN's heads consume.
    return_layers = {"layer4": "out"}
    if aux:
        return_layers["layer3"] = "aux"
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    # 1024 / 2048 are the layer3 / layer4 output widths of the ResNet-50/101 trunks used below.
    aux_classifier = FCNHead(1024, num_classes) if aux else None
    classifier = FCNHead(2048, num_classes)
    return FCN(backbone, classifier, aux_classifier)


@handle_legacy_interface(
    weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fcn_resnet50(
    *,
    weights: Optional[FCN_ResNet50_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> FCN:
    """Fully-Convolutional Network model with a ResNet-50 backbone from the `Fully Convolutional
    Networks for Semantic Segmentation <https://arxiv.org/abs/1411.4038>`_ paper.

    Args:
        weights (:class:`~torchvision.models.segmentation.FCN_ResNet50_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.FCN_ResNet50_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background).
            When ``weights`` is given, it is resolved against the number of categories of those
            weights; otherwise it defaults to 21.
        aux_loss (bool, optional): If True, it uses an auxiliary loss (an auxiliary head is
            attached to the backbone's ``layer3``). When ``weights`` is given, it defaults to True.
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained
            weights for the backbone. Default is ``ResNet50_Weights.IMAGENET1K_V1``. Ignored when
            ``weights`` is given, because the full-model checkpoint is loaded instead.
        **kwargs: accepted but currently ignored by this function. In particular, they are not
            forwarded to the ``torchvision.models.segmentation.fcn.FCN`` class.

    .. autoclass:: torchvision.models.segmentation.FCN_ResNet50_Weights
        :members:
    """
    weights = FCN_ResNet50_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # The full checkpoint already covers the backbone, so no backbone weights are fetched.
        weights_backbone = None
        # NOTE(review): _ovewrite_value_param presumably rejects an explicit conflicting value.
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        # Presumably required because the checkpoint holds aux-head parameters and
        # load_state_dict below is strict.
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        # 21 matches the size of the VOC label set used by the pretrained weights.
        num_classes = 21

    # Dilate layer3/layer4 instead of striding them, keeping higher-resolution feature maps.
    backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _fcn_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


@handle_legacy_interface(
    weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
)
def fcn_resnet101(
    *,
    weights: Optional[FCN_ResNet101_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> FCN:
    """Fully-Convolutional Network model with a ResNet-101 backbone from the `Fully Convolutional
    Networks for Semantic Segmentation <https://arxiv.org/abs/1411.4038>`_ paper.

    Args:
        weights (:class:`~torchvision.models.segmentation.FCN_ResNet101_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.FCN_ResNet101_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background).
            When ``weights`` is given, it is resolved against the number of categories of those
            weights; otherwise it defaults to 21.
        aux_loss (bool, optional): If True, it uses an auxiliary loss (an auxiliary head is
            attached to the backbone's ``layer3``). When ``weights`` is given, it defaults to True.
        weights_backbone (:class:`~torchvision.models.ResNet101_Weights`, optional): The pretrained
            weights for the backbone. Default is ``ResNet101_Weights.IMAGENET1K_V1``. Ignored when
            ``weights`` is given, because the full-model checkpoint is loaded instead.
        **kwargs: accepted but currently ignored by this function. In particular, they are not
            forwarded to the ``torchvision.models.segmentation.fcn.FCN`` class.

    .. autoclass:: torchvision.models.segmentation.FCN_ResNet101_Weights
        :members:
    """
    weights = FCN_ResNet101_Weights.verify(weights)
    weights_backbone = ResNet101_Weights.verify(weights_backbone)

    if weights is not None:
        # The full checkpoint already covers the backbone, so no backbone weights are fetched.
        weights_backbone = None
        # NOTE(review): _ovewrite_value_param presumably rejects an explicit conflicting value.
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        # Presumably required because the checkpoint holds aux-head parameters and
        # load_state_dict below is strict.
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        # 21 matches the size of the VOC label set used by the pretrained weights.
        num_classes = 21

    # Dilate layer3/layer4 instead of striding them, keeping higher-resolution feature maps.
    backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _fcn_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


# The mapping below is an internal implementation detail and will be removed in v0.15.
# It maps legacy model keys to the download URLs of the corresponding pretrained weights.
from .._utils import _ModelURLs


model_urls = _ModelURLs(
    dict(
        fcn_resnet50_coco=FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1.url,
        fcn_resnet101_coco=FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1.url,
    )
)