backbone_utils.py 2.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
from collections import OrderedDict
from torch import nn
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool

from torchvision.ops import misc as misc_nn_ops
from .._utils import IntermediateLayerGetter
from .. import resnet


eellison's avatar
eellison committed
10
class BackboneWithFPN(nn.Module):
    """
    Adds a FPN on top of a model.
    Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
    extract a submodel that returns the feature maps specified in return_layers.
    The same limitations of IntermediateLayerGetter apply here.
    Arguments:
        backbone (nn.Module)
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        in_channels_list (List[int]): number of channels for each feature map
            that is returned, in the order they are present in the OrderedDict
        out_channels (int): number of channels in the FPN.
        extra_blocks (optional): object forwarded unchanged as the
            ``extra_blocks`` argument of FeaturePyramidNetwork. If None
            (the default), a new LastLevelMaxPool() is created, which is
            what this class always used before the parameter existed.
    Attributes:
        body: the IntermediateLayerGetter wrapping ``backbone``
        fpn: the FeaturePyramidNetwork applied to the output of ``body``
        out_channels (int): the number of channels in the FPN
    """
    def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
        super(BackboneWithFPN, self).__init__()

        # Create the default here rather than in the signature so that every
        # instance gets its own module object instead of sharing one.
        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks,
        )
        self.out_channels = out_channels

    def forward(self, x):
        """
        Run ``x`` through the wrapped backbone and then through the FPN.

        Returns whatever ``self.fpn`` returns for the feature maps produced
        by ``self.body``.
        """
        x = self.body(x)
        x = self.fpn(x)
        return x

43

44
def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d,
                        trainable_layers=3):
    """
    Builds a BackboneWithFPN around one of the models exposed by ``resnet``.

    Arguments:
        backbone_name (str): name of the constructor to look up in the
            ``resnet`` module namespace (a KeyError is raised if it is missing).
        pretrained (bool): forwarded to the constructor as ``pretrained``.
        norm_layer: forwarded to the constructor as ``norm_layer``;
            defaults to misc_nn_ops.FrozenBatchNorm2d.
        trainable_layers (int): how many of the blocks
            ``layer4, layer3, layer2, layer1, conv1`` (taken in that order)
            keep their parameters trainable. Must be between 0 and 5.
            The default of 3 leaves layer2, layer3 and layer4 trainable,
            which is the behaviour this function had before the parameter
            was introduced. With 5, ``bn1`` is left trainable as well.

    Returns:
        BackboneWithFPN: the backbone with an FPN of 256 output channels
        fed by the outputs of layer1..layer4 (renamed '0'..'3').

    Raises:
        ValueError: if ``trainable_layers`` is outside the range [0, 5].
    """
    # Validate before constructing the backbone so a bad argument fails fast.
    if trainable_layers < 0 or trainable_layers > 5:
        raise ValueError(
            "trainable_layers should be in the range [0, 5], got {}".format(trainable_layers)
        )

    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained,
        norm_layer=norm_layer)

    # Blocks are listed from the deepest to the shallowest, so slicing keeps
    # the last `trainable_layers` stages of the network trainable.
    layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
    if trainable_layers == 5:
        # The stem's norm layer belongs with conv1 when the whole network is trained.
        layers_to_train.append('bn1')
    trainable_prefixes = tuple(layers_to_train)

    # Freeze every parameter that does not belong to a trainable block.
    # str.startswith with an empty tuple is False, so 0 freezes everything.
    for name, parameter in backbone.named_parameters():
        if not name.startswith(trainable_prefixes):
            parameter.requires_grad_(False)

    return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}

    # NOTE(review): assumes backbone.inplanes holds the channel count of
    # layer4's output, so that // 8 gives layer1's output width and each
    # later stage doubles it -- confirm against the resnet module.
    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [
        in_channels_stage2,
        in_channels_stage2 * 2,
        in_channels_stage2 * 4,
        in_channels_stage2 * 8,
    ]
    out_channels = 256
    return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels)