from collections import OrderedDict

from torch import nn
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
from torchvision.ops import misc as misc_nn_ops

from .._utils import IntermediateLayerGetter
from .. import resnet


class BackboneWithFPN(nn.Sequential):
    """Chain a feature-extracting backbone with a Feature Pyramid Network.

    The module is an ``nn.Sequential`` with two children:

    * ``body``: an ``IntermediateLayerGetter`` wrapping ``backbone``,
      configured with ``return_layers``;
    * ``fpn``: a ``FeaturePyramidNetwork`` built from ``in_channels_list``
      and ``out_channels``, with a ``LastLevelMaxPool`` extra block.

    Args:
        backbone (nn.Module): network whose intermediate outputs feed the FPN.
        return_layers (dict): submodule names of ``backbone`` mapped to the
            keys under which ``IntermediateLayerGetter`` returns their outputs.
        in_channels_list (list[int]): channel count of each returned feature
            map, in the order the FPN receives them.
        out_channels (int): channel count of the FPN outputs.

    Attributes:
        out_channels (int): copy of the ``out_channels`` argument, exposed so
            consumers of this module can read it.
    """

    def __init__(self, backbone, return_layers, in_channels_list, out_channels):
        body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=LastLevelMaxPool(),
        )
        super(BackboneWithFPN, self).__init__(
            OrderedDict([("body", body), ("fpn", fpn)]))
        self.out_channels = out_channels


def resnet_fpn_backbone(backbone_name, pretrained, trainable_layers=3):
    """Build a ResNet backbone topped with a Feature Pyramid Network.

    Args:
        backbone_name (str): name of a constructor in the ``resnet`` module,
            e.g. ``'resnet50'``.
        pretrained (bool): forwarded unchanged to the ResNet constructor.
        trainable_layers (int): number of ResNet stages, counted backwards
            from the deepest, that keep ``requires_grad=True``. Stages are
            taken in the order ``layer4``, ``layer3``, ``layer2``, ``layer1``,
            ``conv1``. Every other parameter is frozen. Must be in ``[0, 5]``.
            The default of 3 trains ``layer2``-``layer4``, which is the
            behaviour this function has always had.

    Returns:
        BackboneWithFPN: returns ``layer1``-``layer4`` under keys ``0``-``3``
        and feeds them to an FPN with 256 output channels.

    Raises:
        KeyError: if ``backbone_name`` is not a name in the ``resnet`` module.
        ValueError: if ``trainable_layers`` is outside ``[0, 5]``.
    """
    if not 0 <= trainable_layers <= 5:
        raise ValueError(
            "trainable_layers must be in [0, 5], got {}".format(trainable_layers))

    # All normalization layers are built as FrozenBatchNorm2d.
    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained,
        norm_layer=misc_nn_ops.FrozenBatchNorm2d)

    # Freeze every parameter outside the trainable stages. Matching the exact
    # top-level module name, rather than a substring of the full parameter
    # path, keeps e.g. 'conv1' from matching 'layer1.0.conv1.weight'.
    layers_to_train = ('layer4', 'layer3', 'layer2', 'layer1', 'conv1')[:trainable_layers]
    for name, parameter in backbone.named_parameters():
        if name.split('.', 1)[0] not in layers_to_train:
            parameter.requires_grad_(False)

    return_layers = {'layer1': 0, 'layer2': 1, 'layer3': 2, 'layer4': 3}

    # This relies on torchvision's ResNet leaving ``inplanes`` equal to layer4's
    # output width (512 * block.expansion), with each earlier stage half as wide.
    #   BasicBlock nets (resnet18/34): layer1-layer4 give 64/128/256/512 channels.
    #   Bottleneck nets: layer1-layer4 give 256/512/1024/2048 channels.
    # The previous hard-coded 256 was only correct for Bottleneck nets.
    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [
        in_channels_stage2,
        in_channels_stage2 * 2,
        in_channels_stage2 * 4,
        in_channels_stage2 * 8,
    ]
    out_channels = 256
    return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels)