Unverified Commit 1cb85abe authored by Alexander Soare's avatar Alexander Soare Committed by GitHub
Browse files

fix fx example so that it works with MaskRCNN (#4376)

parent fbd69f10
...@@ -39,6 +39,7 @@ Here is an example of how we might extract features for MaskRCNN: ...@@ -39,6 +39,7 @@ Here is an example of how we might extract features for MaskRCNN:
from torchvision.models.feature_extraction import get_graph_node_names from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
...@@ -94,10 +95,11 @@ Here is an example of how we might extract features for MaskRCNN: ...@@ -94,10 +95,11 @@ Here is an example of how we might extract features for MaskRCNN:
super(Resnet50WithFPN, self).__init__() super(Resnet50WithFPN, self).__init__()
# Get a resnet50 backbone # Get a resnet50 backbone
m = resnet50() m = resnet50()
# Extract 4 main layers (note: you can also provide a list for return # Extract 4 main layers (note: MaskRCNN needs this particular name
# nodes if the keys and the values are the same) # mapping for return nodes)
self.body = create_feature_extractor( self.body = create_feature_extractor(
m, return_nodes=['layer1', 'layer2', 'layer3', 'layer4']) m, return_nodes={f'layer{k}': str(v)
for v, k in enumerate([1, 2, 3, 4])})
# Dry run to get number of channels for FPN # Dry run to get number of channels for FPN
inp = torch.randn(2, 3, 224, 224) inp = torch.randn(2, 3, 224, 224)
with torch.no_grad(): with torch.no_grad():
...@@ -106,7 +108,8 @@ Here is an example of how we might extract features for MaskRCNN: ...@@ -106,7 +108,8 @@ Here is an example of how we might extract features for MaskRCNN:
# Build FPN # Build FPN
self.out_channels = 256 self.out_channels = 256
self.fpn = FeaturePyramidNetwork( self.fpn = FeaturePyramidNetwork(
in_channels_list, out_channels=self.out_channels) in_channels_list, out_channels=self.out_channels,
extra_blocks=LastLevelMaxPool())
def forward(self, x): def forward(self, x):
x = self.body(x) x = self.body(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment