import warnings
from typing import Callable, Dict, List, Optional, Union

from torch import nn, Tensor
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool

from .. import mobilenet, resnet
from .._api import _get_enum_from_fn, WeightsEnum
from .._utils import handle_legacy_interface, IntermediateLayerGetter

class BackboneWithFPN(nn.Module):
    """
    Wraps a backbone network with a Feature Pyramid Network (FPN).

    The backbone is filtered through
    torchvision.models._utils.IntermediateLayerGetter so that only the
    activations named in ``return_layers`` are produced; those feature maps
    are then fed to a FeaturePyramidNetwork. The same limitations of
    IntermediateLayerGetter apply here.

    Args:
        backbone (nn.Module): the network to extract feature maps from.
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        in_channels_list (List[int]): number of channels for each feature map
            that is returned, in the order they are present in the OrderedDict
        out_channels (int): number of channels in the FPN.
        extra_blocks (ExtraFPNBlock, optional): extra operation applied to the
            FPN output; defaults to ``LastLevelMaxPool``.
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None

    Attributes:
        out_channels (int): the number of channels in the FPN
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            # Fall back to a max-pool extra level when the caller gives none.
            extra_blocks=LastLevelMaxPool() if extra_blocks is None else extra_blocks,
            norm_layer=norm_layer,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        # Backbone feature maps, then merged top-down by the FPN.
        return self.fpn(self.body(x))

@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def resnet_fpn_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 3,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
    """
    Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.

    Examples::

        >>> import torch
        >>> from torchvision.models import ResNet50_Weights
        >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
        >>> backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3)
        >>> # get some dummy image
        >>> x = torch.rand(1,3,64,64)
        >>> # compute the output
        >>> output = backbone(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>>   [('0', torch.Size([1, 256, 16, 16])),
        >>>    ('1', torch.Size([1, 256, 8, 8])),
        >>>    ('2', torch.Size([1, 256, 4, 4])),
        >>>    ('3', torch.Size([1, 256, 2, 2])),
        >>>    ('pool', torch.Size([1, 256, 1, 1]))]

    Args:
        backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
             'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
        weights (WeightsEnum, optional): The pretrained weights for the model
        norm_layer (callable): it is recommended to use the default value. For details visit:
            (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
        trainable_layers (int): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
        returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
            By default, all layers are returned.
        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names. By
            default, a ``LastLevelMaxPool`` is used.
    """
    # Look up the ResNet constructor by name, instantiate it, and hand it off
    # to the shared extractor that freezes layers and attaches the FPN.
    resnet_builder = resnet.__dict__[backbone_name]
    return _resnet_fpn_extractor(
        resnet_builder(weights=weights, norm_layer=norm_layer),
        trainable_layers,
        returned_layers,
        extra_blocks,
    )


def _resnet_fpn_extractor(
    backbone: resnet.ResNet,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:
    """Freeze part of ``backbone`` and wrap it with an FPN.

    Args:
        backbone (resnet.ResNet): the ResNet to extract feature maps from.
        trainable_layers (int): number of trainable (not frozen) layers starting
            from the final block. Valid values are between 0 and 5, with 5
            meaning all backbone layers are trainable.
        returned_layers (list of int, optional): the layers of the network to
            return; each entry must be in ``[1, 4]``. Defaults to all four.
        extra_blocks (ExtraFPNBlock, optional): extra operation applied to the
            FPN output; defaults to ``LastLevelMaxPool``.
        norm_layer (callable, optional): normalization layer for the FPN.

    Returns:
        BackboneWithFPN: the partially frozen backbone with an FPN on top.

    Raises:
        ValueError: if ``trainable_layers`` or ``returned_layers`` is out of range.
    """
    # select layers that won't be frozen
    if trainable_layers < 0 or trainable_layers > 5:
        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
    if trainable_layers == 5:
        layers_to_train.append("bn1")
    # str.startswith accepts a tuple of prefixes; the empty tuple matches
    # nothing, so trainable_layers == 0 freezes every parameter.
    trainable_prefixes = tuple(layers_to_train)
    for name, parameter in backbone.named_parameters():
        if not name.startswith(trainable_prefixes):
            parameter.requires_grad_(False)

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()

    if returned_layers is None:
        returned_layers = [1, 2, 3, 4]
    if min(returned_layers) <= 0 or max(returned_layers) >= 5:
        raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
    return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}

    # ResNet doubles channel width per stage: layer1 outputs inplanes // 8
    # channels (e.g. 256 for resnet50, whose final inplanes is 2048).
    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
    out_channels = 256
    return BackboneWithFPN(
        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
    )


def _validate_trainable_layers(
    is_trained: bool,
    trainable_backbone_layers: Optional[int],
    max_value: int,
    default_value: int,
) -> int:
    # don't freeze any layers if pretrained model or backbone is not used
    if not is_trained:
161
162
        if trainable_backbone_layers is not None:
            warnings.warn(
limm's avatar
limm committed
163
                "Changing trainable_backbone_layers has no effect if "
164
                "neither pretrained nor pretrained_backbone have been set to True, "
limm's avatar
limm committed
165
166
                f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
            )
167
168
169
        trainable_backbone_layers = max_value

    # by default freeze first blocks
170
    if trainable_backbone_layers is None:
171
        trainable_backbone_layers = default_value
limm's avatar
limm committed
172
173
174
175
    if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
        raise ValueError(
            f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
        )
176
    return trainable_backbone_layers


@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def mobilenet_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    fpn: bool,
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 2,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:
    """
    Constructs a MobileNet backbone, optionally topped with an FPN, freezing
    the requested number of stages.

    Args:
        backbone_name (string): mobilenet architecture name, looked up in the
            ``mobilenet`` module (e.g. a MobileNetV2/V3 constructor).
        weights (WeightsEnum, optional): the pretrained weights for the backbone.
        fpn (bool): whether to attach a Feature Pyramid Network on top.
        norm_layer (callable): normalization layer for the backbone.
        trainable_layers (int): number of trainable (not frozen) stages
            starting from the final block.
        returned_layers (list of int, optional): stages of the network whose
            feature maps are returned (used only when ``fpn`` is True).
        extra_blocks (ExtraFPNBlock, optional): extra operation applied to the
            FPN output; defaults to ``LastLevelMaxPool``.
    """
    # Resolve the constructor by name, build the network, then delegate the
    # freezing/FPN wiring to the shared extractor.
    builder = mobilenet.__dict__[backbone_name]
    net = builder(weights=weights, norm_layer=norm_layer)
    return _mobilenet_extractor(net, fpn, trainable_layers, returned_layers, extra_blocks)


def _mobilenet_extractor(
    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
    fpn: bool,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> nn.Module:
    """Freeze part of a MobileNet backbone and optionally put an FPN on top.

    Without an FPN the features are simply projected to 256 channels by a
    1x1 convolution; with an FPN the strided ("CN") stages feed the pyramid.
    Raises ValueError when ``trainable_layers`` or ``returned_layers`` is out
    of range.
    """
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    num_stages = len(stage_indices)

    # find the index of the layer from which we won't freeze
    if trainable_layers < 0 or trainable_layers > num_stages:
        raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

    for block in backbone[:freeze_before]:
        for p in block.parameters():
            p.requires_grad_(False)

    out_channels = 256
    if not fpn:
        # No FPN requested: just reduce the final feature map to out_channels.
        reducer = nn.Sequential(
            backbone,
            # depthwise linear combination of channels to reduce their size
            nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
        )
        reducer.out_channels = out_channels  # type: ignore[assignment]
        return reducer

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()
    if returned_layers is None:
        # Default to the last two stages.
        returned_layers = [num_stages - 2, num_stages - 1]
    if min(returned_layers) < 0 or max(returned_layers) >= num_stages:
        raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")
    # Map each selected stage's module index to its position in the output.
    return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}

    in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
    return BackboneWithFPN(
        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
    )