Unverified Commit 1cb85abe authored by Alexander Soare's avatar Alexander Soare Committed by GitHub
Browse files

fix fx example so that it works with MaskRCNN (#4376)

parent fbd69f10
...@@ -39,6 +39,7 @@ Here is an example of how we might extract features for MaskRCNN: ...@@ -39,6 +39,7 @@ Here is an example of how we might extract features for MaskRCNN:
from torchvision.models.feature_extraction import get_graph_node_names from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
...@@ -57,7 +58,7 @@ Here is an example of how we might extract features for MaskRCNN: ...@@ -57,7 +58,7 @@ Here is an example of how we might extract features for MaskRCNN:
# that appears in each of the main layers: # that appears in each of the main layers:
return_nodes = { return_nodes = {
# node_name: user-specified key for output dict # node_name: user-specified key for output dict
'layer1.2.relu_2': 'layer1', 'layer1.2.relu_2': 'layer1',
'layer2.3.relu_2': 'layer2', 'layer2.3.relu_2': 'layer2',
'layer3.5.relu_2': 'layer3', 'layer3.5.relu_2': 'layer3',
'layer4.2.relu_2': 'layer4', 'layer4.2.relu_2': 'layer4',
...@@ -70,7 +71,7 @@ Here is an example of how we might extract features for MaskRCNN: ...@@ -70,7 +71,7 @@ Here is an example of how we might extract features for MaskRCNN:
# performed is the one that corresponds to the output you desire. You should # performed is the one that corresponds to the output you desire. You should
# consult the source code for the input model to confirm.) # consult the source code for the input model to confirm.)
return_nodes = { return_nodes = {
'layer1': 'layer1', 'layer1': 'layer1',
'layer2': 'layer2', 'layer2': 'layer2',
'layer3': 'layer3', 'layer3': 'layer3',
'layer4': 'layer4', 'layer4': 'layer4',
...@@ -79,7 +80,7 @@ Here is an example of how we might extract features for MaskRCNN: ...@@ -79,7 +80,7 @@ Here is an example of how we might extract features for MaskRCNN:
# Now you can build the feature extractor. This returns a module whose forward # Now you can build the feature extractor. This returns a module whose forward
# method returns a dictionary like: # method returns a dictionary like:
# { # {
# 'layer1': ouput of layer 1, # 'layer1': ouput of layer 1,
# 'layer2': ouput of layer 2, # 'layer2': ouput of layer 2,
# 'layer3': ouput of layer 3, # 'layer3': ouput of layer 3,
# 'layer4': ouput of layer 4, # 'layer4': ouput of layer 4,
...@@ -94,10 +95,11 @@ Here is an example of how we might extract features for MaskRCNN: ...@@ -94,10 +95,11 @@ Here is an example of how we might extract features for MaskRCNN:
super(Resnet50WithFPN, self).__init__() super(Resnet50WithFPN, self).__init__()
# Get a resnet50 backbone # Get a resnet50 backbone
m = resnet50() m = resnet50()
# Extract 4 main layers (note: you can also provide a list for return # Extract 4 main layers (note: MaskRCNN needs this particular name
# nodes if the keys and the values are the same) # mapping for return nodes)
self.body = create_feature_extractor( self.body = create_feature_extractor(
m, return_nodes=['layer1', 'layer2', 'layer3', 'layer4']) m, return_nodes={f'layer{k}': str(v)
for v, k in enumerate([1, 2, 3, 4])})
# Dry run to get number of channels for FPN # Dry run to get number of channels for FPN
inp = torch.randn(2, 3, 224, 224) inp = torch.randn(2, 3, 224, 224)
with torch.no_grad(): with torch.no_grad():
...@@ -106,7 +108,8 @@ Here is an example of how we might extract features for MaskRCNN: ...@@ -106,7 +108,8 @@ Here is an example of how we might extract features for MaskRCNN:
# Build FPN # Build FPN
self.out_channels = 256 self.out_channels = 256
self.fpn = FeaturePyramidNetwork( self.fpn = FeaturePyramidNetwork(
in_channels_list, out_channels=self.out_channels) in_channels_list, out_channels=self.out_channels,
extra_blocks=LastLevelMaxPool())
def forward(self, x): def forward(self, x):
x = self.body(x) x = self.body(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment