Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
1cb85abe
Unverified
Commit
1cb85abe
authored
Sep 07, 2021
by
Alexander Soare
Committed by
GitHub
Sep 07, 2021
Browse files
fix fx example so that it works with MaskRCNN (#4376)
parent
fbd69f10
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
7 deletions
+10
-7
docs/source/feature_extraction.rst
docs/source/feature_extraction.rst
+10
-7
No files found.
docs/source/feature_extraction.rst
View file @
1cb85abe
...
...
@@ -39,6 +39,7 @@ Here is an example of how we might extract features for MaskRCNN:
from
torchvision
.
models
.
feature_extraction
import
get_graph_node_names
from
torchvision
.
models
.
feature_extraction
import
create_feature_extractor
from
torchvision
.
models
.
detection
.
mask_rcnn
import
MaskRCNN
from
torchvision
.
models
.
detection
.
backbone_utils
import
LastLevelMaxPool
from
torchvision
.
ops
.
feature_pyramid_network
import
FeaturePyramidNetwork
...
...
@@ -94,10 +95,11 @@ Here is an example of how we might extract features for MaskRCNN:
super(Resnet50WithFPN, self).__init__()
# Get a resnet50 backbone
m = resnet50()
# Extract 4 main layers (note:
you can also provide a list for return
#
nodes if the keys and the values are the same
)
# Extract 4 main layers (note:
MaskRCNN needs this particular name
#
mapping for return nodes
)
self.body = create_feature_extractor(
m, return_nodes=['
layer1
', '
layer2
', '
layer3
', '
layer4
'])
m, return_nodes={f'
layer
{
k
}
': str(v)
for v, k in enumerate([1, 2, 3, 4])})
# Dry run to get number of channels for FPN
inp = torch.randn(2, 3, 224, 224)
with torch.no_grad():
...
...
@@ -106,7 +108,8 @@ Here is an example of how we might extract features for MaskRCNN:
# Build FPN
self.out_channels = 256
self.fpn = FeaturePyramidNetwork(
in_channels_list, out_channels=self.out_channels)
in_channels_list, out_channels=self.out_channels,
extra_blocks=LastLevelMaxPool())
def forward(self, x):
x = self.body(x)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment