Unverified Commit 1cb85abe authored by Alexander Soare's avatar Alexander Soare Committed by GitHub
Browse files

fix fx example so that it works with MaskRCNN (#4376)

parent fbd69f10
......@@ -39,6 +39,7 @@ Here is an example of how we might extract features for MaskRCNN:
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
......@@ -94,10 +95,11 @@ Here is an example of how we might extract features for MaskRCNN:
super(Resnet50WithFPN, self).__init__()
# Get a resnet50 backbone
m = resnet50()
# Extract 4 main layers (note: you can also provide a list for return
# nodes if the keys and the values are the same)
# Extract 4 main layers (note: MaskRCNN needs this particular name
# mapping for return nodes)
self.body = create_feature_extractor(
m, return_nodes=['layer1', 'layer2', 'layer3', 'layer4'])
m, return_nodes={f'layer{k}': str(v)
for v, k in enumerate([1, 2, 3, 4])})
# Dry run to get number of channels for FPN
inp = torch.randn(2, 3, 224, 224)
with torch.no_grad():
......@@ -106,7 +108,8 @@ Here is an example of how we might extract features for MaskRCNN:
# Build FPN
self.out_channels = 256
self.fpn = FeaturePyramidNetwork(
in_channels_list, out_channels=self.out_channels)
in_channels_list, out_channels=self.out_channels,
extra_blocks=LastLevelMaxPool())
def forward(self, x):
x = self.body(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment