Commit 72d650ae authored by Alexander Soare, committed by GitHub

Add FX feature extraction as an alternative to intermediate_layer_getter (#4302)



* add fx feature extraction util

* Make it possible to use train and eval mode

* FX feature extraction - Tweaks and small bug fixes

* FX feature extraction - add tests

* move to feature_extraction.py, add LeafModuleAwareTracer, add docs

* Tweaks to docs

* addressing latest round of feedback

* undo line spacing changes

* change type hints in docstrings

* fix sphinx indentation

* expose feature_extraction

* add maskrcnn example

* add api reference subheading

* address latest review notes, refactor names, fix regex, cosmetics

* Add back efficientnet to models

* fix tests for effnet

* fix linting issue

* fix test tracer kwargs
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 981ccfdf
torchvision.models.feature_extraction
=====================================

.. currentmodule:: torchvision.models.feature_extraction

Feature extraction utilities let us tap into our models to access intermediate
transformations of our inputs. This could be useful for a variety of
applications in computer vision. Just a few examples are:

- Visualizing feature maps.
- Extracting features to compute image descriptors for tasks like facial
  recognition, copy-detection, or image retrieval.
- Passing selected features to downstream sub-networks for end-to-end training
  with a specific task in mind. For example, passing a hierarchy of features
  to a Feature Pyramid Network with object detection heads.

Torchvision provides :func:`create_feature_extractor` for this purpose.
It works by following roughly these steps:

1. Symbolically tracing the model to get a graphical representation of
   how it transforms the input, step by step.
2. Setting the user-selected graph nodes as outputs.
3. Removing all redundant nodes (anything downstream of the output nodes).
4. Generating Python code from the resulting graph and bundling that into a
   PyTorch module together with the graph itself.

|

The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.
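
To make the steps above concrete, here is a toy, pure-Python sketch. It is not how ``torch.fx`` or :func:`create_feature_extractor` is implemented: the ``Node`` class, the ``GRAPH`` list, and the ``prune`` and ``build_extractor`` helpers are all hypothetical names invented for this illustration, and scalar arithmetic stands in for real layers. It only shows the idea of recording a computation as ordered nodes, marking chosen nodes as outputs, dropping whatever those outputs do not depend on, and wrapping the result in a callable.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class Node:
    """One recorded step of the computation (hypothetical, for illustration)."""
    name: str
    fn: Callable[..., float]
    inputs: Tuple[str, ...]


# Step 1 (stand-in for symbolic tracing): the model as an ordered list of nodes.
GRAPH: List[Node] = [
    Node("conv", lambda x: x * 2.0, ("x",)),
    Node("relu", lambda v: max(v, 0.0), ("conv",)),
    Node("pool", lambda v: v / 4.0, ("relu",)),
    Node("fc", lambda v: v + 1.0, ("pool",)),
]


def prune(graph: List[Node], outputs: List[str]) -> List[Node]:
    """Steps 2 and 3: keep only the nodes the chosen outputs depend on."""
    needed = set(outputs)
    kept = []
    for node in reversed(graph):
        if node.name in needed:
            kept.append(node)
            needed.update(node.inputs)
    kept.reverse()
    return kept


def build_extractor(graph: List[Node], return_nodes: Dict[str, str]):
    """Step 4: bundle the pruned graph into a callable that returns a dict."""
    pruned = prune(graph, list(return_nodes))

    def extractor(x: float) -> Dict[str, float]:
        env = {"x": x}
        for node in pruned:
            env[node.name] = node.fn(*(env[i] for i in node.inputs))
        return {key: env[name] for name, key in return_nodes.items()}

    # Expose the surviving node names so the pruning can be inspected.
    extractor.node_names = [n.name for n in pruned]
    return extractor


extract = build_extractor(GRAPH, {"relu": "feat"})
features = extract(3.0)
```

In this sketch ``pool`` and ``fc`` sit downstream of the selected ``relu`` node, so they are dropped and never executed.
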
Here is an example of how we might extract features for MaskRCNN:

.. code-block:: python

    import torch
    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import get_graph_node_names
    from torchvision.models.feature_extraction import create_feature_extractor
    from torchvision.models.detection.mask_rcnn import MaskRCNN
    from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork


    # To assist you in designing the feature extractor you may want to print out
    # the available nodes for resnet50.
    m = resnet50()
    train_nodes, eval_nodes = get_graph_node_names(m)

    # The returned lists are the names of all the graph nodes (in order of
    # execution) for the input model traced in train mode and in eval mode
    # respectively. You'll find that `train_nodes` and `eval_nodes` are the same
    # for this example. But if the model contains control flow that's dependent
    # on the training mode, they may be different.

    # To specify the nodes you want to extract, you could select the final node
    # that appears in each of the main layers:
    return_nodes = {
        # node_name: user-specified key for output dict
        'layer1.2.relu_2': 'layer1',
        'layer2.3.relu_2': 'layer2',
        'layer3.5.relu_2': 'layer3',
        'layer4.2.relu_2': 'layer4',
    }

    # But `create_feature_extractor` can also accept truncated node specifications
    # like "layer1", as it will just pick the last node that's a descendant of
    # the specification. (Tip: be careful with this, especially when a layer
    # has multiple outputs. It's not always guaranteed that the last operation
    # performed is the one that corresponds to the output you desire. You should
    # consult the source code for the input model to confirm.)
    return_nodes = {
        'layer1': 'layer1',
        'layer2': 'layer2',
        'layer3': 'layer3',
        'layer4': 'layer4',
    }

    # Now you can build the feature extractor. This returns a module whose forward
    # method returns a dictionary like:
    # {
    #     'layer1': output of layer 1,
    #     'layer2': output of layer 2,
    #     'layer3': output of layer 3,
    #     'layer4': output of layer 4,
    # }
    create_feature_extractor(m, return_nodes=return_nodes)

    # Let's put all that together to wrap resnet50 with MaskRCNN

    # MaskRCNN requires a backbone with an attached FPN
    class Resnet50WithFPN(torch.nn.Module):
        def __init__(self):
            super(Resnet50WithFPN, self).__init__()
            # Get a resnet50 backbone
            m = resnet50()
            # Extract 4 main layers (note: you can also provide a list for return
            # nodes if the keys and the values are the same)
            self.body = create_feature_extractor(
                m, return_nodes=['layer1', 'layer2', 'layer3', 'layer4'])
            # Dry run to get number of channels for FPN
            inp = torch.randn(2, 3, 224, 224)
            with torch.no_grad():
                out = self.body(inp)
            in_channels_list = [o.shape[1] for o in out.values()]
            # Build FPN
            self.out_channels = 256
            self.fpn = FeaturePyramidNetwork(
                in_channels_list, out_channels=self.out_channels)

        def forward(self, x):
            x = self.body(x)
            x = self.fpn(x)
            return x


    # Now we can build our model!
    model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()

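
As a rough picture of how a truncated specification such as ``"layer1"`` could map onto a full node name, the sketch below picks the last node, in execution order, whose name equals the specification or starts with it followed by a dot. The ``resolve_return_nodes`` helper, the sample node names, and the exact matching rule are assumptions made for illustration; the real logic inside :func:`create_feature_extractor` may differ. Raising ``ValueError`` for an unknown name mirrors what the tests in this commit expect.

```python
from typing import Dict, List


def resolve_return_nodes(
    node_names: List[str], return_nodes: Dict[str, str]
) -> Dict[str, str]:
    """Map each (possibly truncated) spec to the last matching node name.

    Hypothetical helper: the matching rule here is an assumption, not the
    torchvision implementation.
    """
    resolved = {}
    for spec, key in return_nodes.items():
        matches = [
            n for n in node_names if n == spec or n.startswith(spec + ".")
        ]
        if not matches:
            raise ValueError(f"no graph node matches {spec!r}")
        # node_names is in execution order, so the last match runs last.
        resolved[matches[-1]] = key
    return resolved


# Made-up node names, listed in execution order.
node_names = [
    "x", "conv1",
    "layer1.0.relu", "layer1.1.conv1", "layer1.1.relu",
    "layer2.0.relu",
    "fc",
]
resolved = resolve_return_nodes(
    node_names, {"layer1": "feat1", "layer2": "feat2"})
```

Because the choice depends only on execution order, a layer with several outputs may resolve to a node other than the one you want, which is why checking the full names from ``get_graph_node_names`` is worthwhile.
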
API Reference
-------------

.. autofunction:: create_feature_extractor

.. autofunction:: get_graph_node_names
\ No newline at end of file
@@ -34,6 +34,7 @@ architectures, and common image transformations for computer vision.
   datasets
   io
   models
   feature_extraction
   ops
   transforms
   utils
from functools import partial
from itertools import chain
import random

import torch
from torchvision import models
import torchvision
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models._utils import IntermediateLayerGetter

import pytest

from common_utils import set_rng_seed


def get_available_models():
    # TODO add a registration mechanism to torchvision.models
    return [k for k, v in models.__dict__.items()
            if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50'))
def test_resnet_fpn_backbone(backbone_name):
    x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu')
    y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
    assert list(y.keys()) == ['0', '1', '2', '3', 'pool']


# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
    return int(x)


class TestFxFeatureExtraction:
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device='cpu')
    model_defaults = {
        'num_classes': 1,
        'pretrained': False
    }
    leaf_modules = [torchvision.ops.StochasticDepth]

    def _create_feature_extractor(self, *args, **kwargs):
        """
        Apply leaf modules
        """
        tracer_kwargs = {}
        if 'tracer_kwargs' not in kwargs:
            tracer_kwargs = {'leaf_modules': self.leaf_modules}
        else:
            tracer_kwargs = kwargs.pop('tracer_kwargs')
        return create_feature_extractor(
            *args, **kwargs,
            tracer_kwargs=tracer_kwargs,
            suppress_diff_warning=True)

    def _get_return_nodes(self, model):
        set_rng_seed(0)
        exclude_nodes_filter = ['getitem', 'floordiv', 'size', 'chunk']
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={'leaf_modules': self.leaf_modules},
            suppress_diff_warning=True)
        # Get rid of any nodes that don't return tensors as they cause issues
        # when testing backward pass.
        train_nodes = [n for n in train_nodes
                       if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes
                      if not any(x in n for x in exclude_nodes_filter)]
        return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)

    @pytest.mark.parametrize('model_name', get_available_models())
    def test_build_fx_feature_extractor(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).eval()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        # Check that it works with both a list and dict for return nodes
        self._create_feature_extractor(
            model, train_return_nodes={v: v for v in train_return_nodes},
            eval_return_nodes=eval_return_nodes)
        self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        # Check must specify return nodes
        with pytest.raises(AssertionError):
            self._create_feature_extractor(model)
        # Check return_nodes and train_return_nodes / eval_return nodes
        # mutual exclusivity
        with pytest.raises(AssertionError):
            self._create_feature_extractor(
                model, return_nodes=train_return_nodes,
                train_return_nodes=train_return_nodes)
        # Check train_return_nodes / eval_return nodes must both be specified
        with pytest.raises(AssertionError):
            self._create_feature_extractor(
                model, train_return_nodes=train_return_nodes)
        # Check invalid node name raises ValueError
        with pytest.raises(ValueError):
            # First just double check that this node really doesn't exist
            if not any(n.startswith('l') or n.startswith('l.') for n
                       in chain(train_return_nodes, eval_return_nodes)):
                self._create_feature_extractor(
                    model, train_return_nodes=['l'], eval_return_nodes=['l'])
            else:  # otherwise skip this check
                raise ValueError

    @pytest.mark.parametrize('model_name', get_available_models())
    def test_forward_backward(self, model_name):
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        out = model(self.inp)
        sum([o.mean() for o in out.values()]).backward()

    def test_feature_extraction_methods_equivalence(self):
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {
            'layer1': 'layer1',
            'layer2': 'layer2',
            'layer3': 'layer3',
            'layer4': 'layer4'
        }

        ilg_model = IntermediateLayerGetter(
            model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)

        # Check that we have same parameters
        for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(),
                                      fx_model.named_parameters()):
            assert n1 == n2
            assert p1.equal(p2)

        # And that outputs match
        with torch.no_grad():
            ilg_out = ilg_model(self.inp)
            fgn_out = fx_model(self.inp)
        assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys()))
        for k in ilg_out.keys():
            assert ilg_out[k].equal(fgn_out[k])

    @pytest.mark.parametrize('model_name', get_available_models())
    def test_jit_forward_backward(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
        sum([o.mean() for o in fgn_out.values()]).backward()

    def test_train_eval(self):
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(p=1.)

            def forward(self, x):
                x = x.mean()
                x = self.dropout(x)  # dropout
                if self.training:
                    x += 100  # add
                else:
                    x *= 0  # mul
                x -= 0  # sub
                return x

        model = TestModel()

        train_return_nodes = ['dropout', 'add', 'sub']
        eval_return_nodes = ['dropout', 'mul', 'sub']

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == 'train':
                # Check that dropout is respected
                assert out['dropout'].item() == 0
                # Check that control flow dependent on training_mode is respected
                assert out['sub'].item() == 100
                assert 'add' in out
                assert 'mul' not in out
            elif mode == 'eval':
                # Check that dropout is respected
                assert out['dropout'].item() == 1
                # Check that control flow dependent on training_mode is respected
                assert out['sub'].item() == 0
                assert 'mul' in out
                assert 'add' not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, 'train')
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, 'eval')

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, 'eval')
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, 'train')

    def test_leaf_module_and_function(self):
        class LeafModule(torch.nn.Module):
            def forward(self, x):
                # This would raise a TypeError if it were not in a leaf module
                int(x.shape[0])
                return torch.nn.functional.relu(x + 4)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3)
                self.leaf_module = LeafModule()

            def forward(self, x):
                leaf_function(x.shape[0])
                x = self.conv(x)
                return self.leaf_module(x)

        model = self._create_feature_extractor(
            TestModule(), return_nodes=['leaf_module'],
            tracer_kwargs={'leaf_modules': [LeafModule],
                           'autowrap_functions': [leaf_function]}).train()

        # Check that LeafModule is not in the list of nodes
        assert 'relu' not in [str(n) for n in model.graph.nodes]
        assert 'leaf_module' in [str(n) for n in model.graph.nodes]

        # Check forward
        out = model(self.inp)
        # And backward
        out['leaf_module'].mean().backward()
@@ -13,3 +13,4 @@ from . import segmentation
from . import detection
from . import video
from . import quantization
from . import feature_extraction