lraspp.py 6.18 KB
Newer Older
1
from collections import OrderedDict
2
3
from functools import partial
from typing import Any, Dict, Optional
4
5
6
7

from torch import nn, Tensor
from torch.nn import functional as F

8
from ...transforms._presets import SemanticSegmentation
9
from ...utils import _log_api_usage_once
10
11
12
13
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large
14

15

16
__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"]
17
18
19
20
21
22
23
24
25
26
27
28
29
30


class LRASPP(nn.Module):
    """
    Implements a Lite R-ASPP Network for semantic segmentation from
    `"Searching for MobileNetV3"
    <https://arxiv.org/abs/1905.02244>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "high" for the high level feature map and "low" for the low level feature map.
        low_channels (int): the number of channels of the low level features.
        high_channels (int): the number of channels of the high level features.
        num_classes (int): number of output classes of the model (including the background).
        inter_channels (int, optional): the number of channels for intermediate computations.
            Default: 128.
    """

    def __init__(
        self, backbone: nn.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        # ``backbone`` and ``classifier`` become the state_dict key prefixes of
        # the submodules, so these attribute names must stay stable for checkpoints.
        self.backbone = backbone
        self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)

    def forward(self, input: Tensor) -> Dict[str, Tensor]:
        """Run the backbone and the segmentation head on ``input``.

        Args:
            input (Tensor): image batch; its last two dimensions are used as the
                target spatial size of the output.

        Returns:
            Dict[str, Tensor]: an ``OrderedDict`` with a single ``"out"`` entry holding
            the per-class scores resized to the spatial size of ``input``.
        """
        features = self.backbone(input)
        out = self.classifier(features)
        # The head predicts at the spatial size of the "low" features; resize
        # back to the input resolution so the scores align with the image pixels.
        out = F.interpolate(out, size=input.shape[-2:], mode="bilinear", align_corners=False)

        result = OrderedDict()
        result["out"] = out

        return result


class LRASPPHead(nn.Module):
    """Lite R-ASPP segmentation head.

    Combines a gated 1x1 projection of the high level features with a direct
    1x1 classification of the low level features.

    Args:
        low_channels (int): number of channels of the ``"low"`` feature map.
        high_channels (int): number of channels of the ``"high"`` feature map.
        num_classes (int): number of output classes (including the background).
        inter_channels (int): number of channels of the projected high level features.
    """

    def __init__(self, low_channels: int, high_channels: int, num_classes: int, inter_channels: int) -> None:
        super().__init__()
        # 1x1 conv -> BatchNorm -> ReLU projection of the high level features.
        self.cbr = nn.Sequential(
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
        )
        # Per-channel gate in (0, 1): global average pool to 1x1, 1x1 conv, sigmoid.
        self.scale = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.Sigmoid(),
        )
        self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
        self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)

    def forward(self, input: Dict[str, Tensor]) -> Tensor:
        """Compute per-class scores from the ``"low"`` and ``"high"`` feature maps.

        Args:
            input (Dict[str, Tensor]): mapping with ``"low"`` and ``"high"`` feature maps.

        Returns:
            Tensor: scores with ``num_classes`` channels at the spatial size of ``input["low"]``.
        """
        low = input["low"]
        high = input["high"]

        x = self.cbr(high)
        s = self.scale(high)
        # ``s`` is spatially 1x1 and broadcasts over the projected feature map.
        x = x * s
        # Bring the gated high level branch to the low level resolution before fusing.
        x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False)

        return self.low_classifier(low) + self.high_classifier(x)
80
81


82
def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP:
    """Build an :class:`LRASPP` model from a MobileNetV3 network.

    Only ``backbone.features`` is used. It is wrapped so that it returns the
    output of an intermediate stage under ``"low"`` and the output of its last
    block under ``"high"``, which is the input format expected by :class:`LRASPP`.

    Args:
        backbone (MobileNetV3): network whose ``features`` sequence provides the feature maps.
        num_classes (int): number of output classes of the model (including the background).

    Returns:
        LRASPP: the assembled segmentation model.
    """
    features = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(features) if getattr(b, "_is_cn", False)] + [len(features) - 1]
    low_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    # use C5 which has output_stride = 16 (presumably relies on the backbone being
    # built with dilated=True, as lraspp_mobilenet_v3_large does — confirm)
    high_pos = stage_indices[-1]
    low_channels = features[low_pos].out_channels
    high_channels = features[high_pos].out_channels
    # Children of an nn.Sequential are named "0", "1", ..., hence the str() keys.
    extractor = IntermediateLayerGetter(features, return_layers={str(low_pos): "low", str(high_pos): "high"})

    return LRASPP(extractor, low_channels, high_channels, num_classes)


96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
    """Pretrained weights available for ``lraspp_mobilenet_v3_large``.

    Each entry wraps a ``Weights`` record holding the checkpoint URL, the
    inference preprocessing preset and descriptive metadata.
    """

    # Checkpoint whose label set is ``_VOC_CATEGORIES``. The name suggests COCO
    # images restricted to the VOC labels — see the ``recipe`` URL to confirm.
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
        # Preprocessing preset; resize_size=520 is forwarded to SemanticSegmentation,
        # whose exact resize semantics are defined there.
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            "task": "image_semantic_segmentation",
            "architecture": "LRASPP",
            "num_params": 3221538,  # presumably the model's parameter count
            "categories": _VOC_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
            # Reported evaluation metrics (presumably on the recipe's validation set — confirm).
            "mIoU": 57.9,
            "acc": 91.2,
        },
    )
    # Recommended weights; refers to the same Weights object as COCO_WITH_VOC_LABELS_V1.
    DEFAULT = COCO_WITH_VOC_LABELS_V1


@handle_legacy_interface(
    weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def lraspp_mobilenet_v3_large(
    *,
    weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> LRASPP:
    """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.

    Args:
        weights (LRASPP_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model.
            When given, ``weights_backbone`` is ignored and ``num_classes`` is derived
            from the number of categories in the weights' metadata.
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background).
            Defaults to 21 when no ``weights`` are given.
        weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone.
            Only used when ``weights`` is None.
        **kwargs: ``aux_loss=True`` is rejected because this model has no auxiliary
            head; any other keyword arguments are not used by this function.

    Raises:
        NotImplementedError: if ``aux_loss=True`` is passed.
    """
    if kwargs.pop("aux_loss", False):
        raise NotImplementedError("This model does not use auxiliary loss")

    weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    if weights is not None:
        # The full-model state dict loaded below also covers the backbone,
        # so the separate backbone weights are not needed.
        weights_backbone = None
        # NOTE(review): _ovewrite_value_param presumably rejects a conflicting
        # user-supplied num_classes — confirm against its definition.
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 21

    # dilated=True presumably keeps the last stage at a finer output stride for
    # the "high" features — confirm against mobilenet_v3_large.
    backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
    model = _lraspp_mobilenetv3(backbone, num_classes)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model