backbone_utils.py 10.2 KB
Newer Older
1
import warnings
2
from typing import Callable, Dict, List, Optional, Union
3

4
from torch import nn, Tensor
5
from torchvision.ops import misc as misc_nn_ops
6
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
7

8
from .. import mobilenet, resnet
9
10
from .._api import _get_enum_from_fn, WeightsEnum
from .._utils import handle_legacy_interface, IntermediateLayerGetter
11
12


eellison's avatar
eellison committed
13
class BackboneWithFPN(nn.Module):
    """Attach a Feature Pyramid Network on top of an arbitrary backbone.

    The backbone is wrapped in ``torchvision.models._utils.IntermediateLayerGetter``
    so that the intermediate activations named in ``return_layers`` are
    collected, and those feature maps are then passed through a
    ``FeaturePyramidNetwork``. The same limitations of
    ``IntermediateLayerGetter`` apply here.

    Args:
        backbone (nn.Module): the model to extract feature maps from.
        return_layers (Dict[str, str]): maps the name of a backbone module
            (the dict key) to the name under which its activation is returned
            (the dict value, chosen by the user).
        in_channels_list (List[int]): number of channels of each returned
            feature map, in the order they appear in the resulting OrderedDict.
        out_channels (int): number of channels produced by the FPN.
        extra_blocks (ExtraFPNBlock, optional): extra operation appended to the
            FPN output; when ``None``, a ``LastLevelMaxPool`` is used.
        norm_layer (callable, optional): Module specifying the normalization
            layer to use inside the FPN. Default: None

    Attributes:
        out_channels (int): the number of channels in the FPN
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            # Default extra level matches the torchvision detection models.
            extra_blocks=LastLevelMaxPool() if extra_blocks is None else extra_blocks,
            norm_layer=norm_layer,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        # Collect the requested intermediate activations, then run the FPN.
        features = self.body(x)
        return self.fpn(features)

61

62
63
64
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def resnet_fpn_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 3,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
    """
    Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.

    Examples::

        >>> import torch
        >>> from torchvision.models import ResNet50_Weights
        >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
        >>> # all parameters are keyword-only
        >>> backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3)
        >>> # get some dummy image
        >>> x = torch.rand(1,3,64,64)
        >>> # compute the output
        >>> output = backbone(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>>   [('0', torch.Size([1, 256, 16, 16])),
        >>>    ('1', torch.Size([1, 256, 8, 8])),
        >>>    ('2', torch.Size([1, 256, 4, 4])),
        >>>    ('3', torch.Size([1, 256, 2, 2])),
        >>>    ('pool', torch.Size([1, 256, 1, 1]))]

    Args:
        backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
             'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
        weights (WeightsEnum, optional): The pretrained weights for the model
        norm_layer (callable): it is recommended to use the default value. For details visit:
            (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
        trainable_layers (int): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
        returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
            By default, all layers are returned.
        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names. By
            default, a ``LastLevelMaxPool`` is used.
    """
    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
    return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)
114

115

116
def _resnet_fpn_extractor(
    backbone: resnet.ResNet,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:
    """Freeze part of a ResNet and wrap it in a :class:`BackboneWithFPN`.

    ``trainable_layers`` counts unfrozen layers starting from the final block
    (``layer4`` down to ``conv1``, plus ``bn1`` when all 5 are trainable).
    """
    if trainable_layers < 0 or trainable_layers > 5:
        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
    # Layers are unfrozen from the top of the network downwards.
    trainable_prefixes = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
    if trainable_layers == 5:
        trainable_prefixes.append("bn1")
    for name, parameter in backbone.named_parameters():
        if not any(name.startswith(prefix) for prefix in trainable_prefixes):
            parameter.requires_grad_(False)

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()
    if returned_layers is None:
        returned_layers = [1, 2, 3, 4]
    if min(returned_layers) <= 0 or max(returned_layers) >= 5:
        raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
    # Map e.g. "layer1" -> "0", "layer2" -> "1", ...
    return_layers = {f"layer{layer}": str(idx) for idx, layer in enumerate(returned_layers)}

    # Channel count doubles at each stage; stage 2 has inplanes // 8 channels.
    stage2_channels = backbone.inplanes // 8
    in_channels_list = [stage2_channels << (layer - 1) for layer in returned_layers]
    out_channels = 256
    return BackboneWithFPN(
        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
    )
149
150


151
def _validate_trainable_layers(
152
    is_trained: bool,
153
154
155
156
157
    trainable_backbone_layers: Optional[int],
    max_value: int,
    default_value: int,
) -> int:
    # don't freeze any layers if pretrained model or backbone is not used
158
    if not is_trained:
159
160
161
162
        if trainable_backbone_layers is not None:
            warnings.warn(
                "Changing trainable_backbone_layers has not effect if "
                "neither pretrained nor pretrained_backbone have been set to True, "
163
                f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
164
            )
165
166
167
        trainable_backbone_layers = max_value

    # by default freeze first blocks
168
    if trainable_backbone_layers is None:
169
        trainable_backbone_layers = default_value
170
171
172
173
    if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
        raise ValueError(
            f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
        )
174
    return trainable_backbone_layers
175
176


177
178
179
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def mobilenet_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    fpn: bool,
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 2,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:
    """Construct a MobileNet backbone for detection, optionally with an FPN.

    Args:
        backbone_name (string): mobilenet architecture name, looked up in
            ``torchvision.models.mobilenet`` (e.g. 'mobilenet_v2', 'mobilenet_v3_large').
        weights (WeightsEnum, optional): The pretrained weights for the model
        fpn (bool): whether to attach a Feature Pyramid Network on top.
        norm_layer (callable): normalization layer passed to the mobilenet
            constructor. Default: FrozenBatchNorm2d.
        trainable_layers (int): number of trainable (not frozen) stages
            starting from the final block.
        returned_layers (list of int, optional): which stages to return when
            ``fpn`` is True. By default, the last two stages are returned.
        extra_blocks (ExtraFPNBlock or None): extra FPN operation; by default
            a ``LastLevelMaxPool`` is used.
    """
    model = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
    return _mobilenet_extractor(model, fpn, trainable_layers, returned_layers, extra_blocks)
195

196

197
198
199
def _mobilenet_extractor(
    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
    fpn: bool,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> nn.Module:
    """Freeze part of a MobileNet feature extractor and optionally add an FPN.

    Returns a :class:`BackboneWithFPN` when ``fpn`` is True, otherwise the
    feature extractor followed by a 1x1 conv projecting to 256 channels.
    """
    features = backbone.features
    # Gather the indices of blocks which are strided. These are the locations
    # of C1, ..., Cn-1 blocks. The first and last blocks are always included
    # because they are the C0 (conv1) and Cn.
    strided_blocks = [idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)]
    stage_indices = [0, *strided_blocks, len(features) - 1]
    num_stages = len(stage_indices)

    # find the index of the layer from which we won't freeze
    if trainable_layers < 0 or trainable_layers > num_stages:
        raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
    freeze_before = len(features) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

    for block in features[:freeze_before]:
        for parameter in block.parameters():
            parameter.requires_grad_(False)

    out_channels = 256
    if not fpn:
        # No FPN: project the last feature map to ``out_channels`` channels
        # with a 1x1 convolution.
        model = nn.Sequential(
            features,
            nn.Conv2d(features[-1].out_channels, out_channels, 1),
        )
        model.out_channels = out_channels  # type: ignore[assignment]
        return model

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()
    if returned_layers is None:
        # By default return the two topmost stages.
        returned_layers = [num_stages - 2, num_stages - 1]
    if min(returned_layers) < 0 or max(returned_layers) >= num_stages:
        raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")
    # Map the block index of each selected stage to its output position name.
    return_layers = {f"{stage_indices[stage]}": str(pos) for pos, stage in enumerate(returned_layers)}
    in_channels_list = [features[stage_indices[stage]].out_channels for stage in returned_layers]
    return BackboneWithFPN(
        features, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
    )