Unverified Commit 72d650ae authored by Alexander Soare's avatar Alexander Soare Committed by GitHub
Browse files

Add FX feature extraction as an alternative to intermediate_layer_getter (#4302)



* add fx feature extraction util

* Make it possible to use train and eval mode

* FX feature extraction - Tweaks and small bug fixes

* FX feature extraction - add tests

* move to feature_extraction.py, add LeafModuleAwareTracer, add docs

* Tweaks to docs

* addressing latest round of feedback

* undo line spacing changes

* change type hints in docstrings

* fix sphinx indentation

* expose feature_extraction

* add maskrcnn example

* add api refernce subheading

* address latest review notes, refactor names, fix regex, cosmetics

* Add back efficientnet to models

* fix tests for effnet

* fix linting issue

* fix test tracer kwargs
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 981ccfdf
torchvision.models.feature_extraction
=====================================
.. currentmodule:: torchvision.models.feature_extraction
Feature extraction utilities let us tap into our models to access intermediate
transformations of our inputs. This could be useful for a variety of
applications in computer vision. Just a few examples are:
- Visualizing feature maps.
- Extracting features to compute image descriptors for tasks like facial
recognition, copy-detection, or image retrieval.
- Passing selected features to downstream sub-networks for end-to-end training
with a specific task in mind. For example, passing a hierarchy of features
to a Feature Pyramid Network with object detection heads.
Torchvision provides :func:`create_feature_extractor` for this purpose.
It works by following roughly these steps:
1. Symbolically tracing the model to get a graphical representation of
how it transforms the input, step by step.
2. Setting the user-selected graph nodes as ouputs.
3. Removing all redundant nodes (anything downstream of the ouput nodes).
4. Generating python code from the resulting graph and bundling that into a
PyTorch module together with the graph itself.
|
The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.
Here is an example of how we might extract features for MaskRCNN:
.. code-block:: python
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
# To assist you in designing the feature extractor you may want to print out
# the available nodes for resnet50.
m = resnet50()
train_nodes, eval_nodes = get_graph_node_names(resnet50())
# The lists returned, are the names of all the graph nodes (in order of
# execution) for the input model traced in train mode and in eval mode
# respectively. You'll find that `train_nodes` and `eval_nodes` are the same
# for this example. But if the model contains control flow that's dependent
# on the training mode, they may be different.
# To specify the nodes you want to extract, you could select the final node
# that appears in each of the main layers:
return_nodes = {
# node_name: user-specified key for output dict
'layer1.2.relu_2': 'layer1',
'layer2.3.relu_2': 'layer2',
'layer3.5.relu_2': 'layer3',
'layer4.2.relu_2': 'layer4',
}
# But `create_feature_extractor` can also accept truncated node specifications
# like "layer1", as it will just pick the last node that's a descendent of
# of the specification. (Tip: be careful with this, especially when a layer
# has multiple outputs. It's not always guaranteed that the last operation
# performed is the one that corresponds to the output you desire. You should
# consult the source code for the input model to confirm.)
return_nodes = {
'layer1': 'layer1',
'layer2': 'layer2',
'layer3': 'layer3',
'layer4': 'layer4',
}
# Now you can build the feature extractor. This returns a module whose forward
# method returns a dictionary like:
# {
# 'layer1': ouput of layer 1,
# 'layer2': ouput of layer 2,
# 'layer3': ouput of layer 3,
# 'layer4': ouput of layer 4,
# }
create_feature_extractor(m, return_nodes=return_nodes)
# Let's put all that together to wrap resnet50 with MaskRCNN
# MaskRCNN requires a backbone with an attached FPN
class Resnet50WithFPN(torch.nn.Module):
def __init__(self):
super(Resnet50WithFPN, self).__init__()
# Get a resnet50 backbone
m = resnet50()
# Extract 4 main layers (note: you can also provide a list for return
# nodes if the keys and the values are the same)
self.body = create_feature_extractor(
m, return_nodes=['layer1', 'layer2', 'layer3', 'layer4'])
# Dry run to get number of channels for FPN
inp = torch.randn(2, 3, 224, 224)
with torch.no_grad():
out = self.body(inp)
in_channels_list = [o.shape[1] for o in out.values()]
# Build FPN
self.out_channels = 256
self.fpn = FeaturePyramidNetwork(
in_channels_list, out_channels=self.out_channels)
def forward(self, x):
x = self.body(x)
x = self.fpn(x)
return x
# Now we can build our model!
model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()
API Reference
-------------
.. autofunction:: create_feature_extractor
.. autofunction:: get_graph_node_names
\ No newline at end of file
......@@ -34,6 +34,7 @@ architectures, and common image transformations for computer vision.
datasets
io
models
feature_extraction
ops
transforms
utils
......
from functools import partial
from itertools import chain
import random
import torch
from torchvision import models
import torchvision
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models._utils import IntermediateLayerGetter
import pytest
from common_utils import set_rng_seed
def get_available_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.__dict__.items()
if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50'))
def test_resnet_fpn_backbone(backbone_name):
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu')
y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
assert list(y.keys()) == ['0', '1', '2', '3', 'pool']
# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
return int(x)
class TestFxFeatureExtraction:
inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device='cpu')
model_defaults = {
'num_classes': 1,
'pretrained': False
}
leaf_modules = [torchvision.ops.StochasticDepth]
def _create_feature_extractor(self, *args, **kwargs):
"""
Apply leaf modules
"""
tracer_kwargs = {}
if 'tracer_kwargs' not in kwargs:
tracer_kwargs = {'leaf_modules': self.leaf_modules}
else:
tracer_kwargs = kwargs.pop('tracer_kwargs')
return create_feature_extractor(
*args, **kwargs,
tracer_kwargs=tracer_kwargs,
suppress_diff_warning=True)
def _get_return_nodes(self, model):
set_rng_seed(0)
exclude_nodes_filter = ['getitem', 'floordiv', 'size', 'chunk']
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': self.leaf_modules},
suppress_diff_warning=True)
# Get rid of any nodes that don't return tensors as they cause issues
# when testing backward pass.
train_nodes = [n for n in train_nodes
if not any(x in n for x in exclude_nodes_filter)]
eval_nodes = [n for n in eval_nodes
if not any(x in n for x in exclude_nodes_filter)]
return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)
@pytest.mark.parametrize('model_name', get_available_models())
def test_build_fx_feature_extractor(self, model_name):
set_rng_seed(0)
model = models.__dict__[model_name](**self.model_defaults).eval()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
# Check that it works with both a list and dict for return nodes
self._create_feature_extractor(
model, train_return_nodes={v: v for v in train_return_nodes},
eval_return_nodes=eval_return_nodes)
self._create_feature_extractor(
model, train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes)
# Check must specify return nodes
with pytest.raises(AssertionError):
self._create_feature_extractor(model)
# Check return_nodes and train_return_nodes / eval_return nodes
# mutual exclusivity
with pytest.raises(AssertionError):
self._create_feature_extractor(
model, return_nodes=train_return_nodes,
train_return_nodes=train_return_nodes)
# Check train_return_nodes / eval_return nodes must both be specified
with pytest.raises(AssertionError):
self._create_feature_extractor(
model, train_return_nodes=train_return_nodes)
# Check invalid node name raises ValueError
with pytest.raises(ValueError):
# First just double check that this node really doesn't exist
if not any(n.startswith('l') or n.startswith('l.') for n
in chain(train_return_nodes, eval_return_nodes)):
self._create_feature_extractor(
model, train_return_nodes=['l'], eval_return_nodes=['l'])
else: # otherwise skip this check
raise ValueError
@pytest.mark.parametrize('model_name', get_available_models())
def test_forward_backward(self, model_name):
model = models.__dict__[model_name](**self.model_defaults).train()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes)
out = model(self.inp)
sum([o.mean() for o in out.values()]).backward()
def test_feature_extraction_methods_equivalence(self):
model = models.resnet18(**self.model_defaults).eval()
return_layers = {
'layer1': 'layer1',
'layer2': 'layer2',
'layer3': 'layer3',
'layer4': 'layer4'
}
ilg_model = IntermediateLayerGetter(
model, return_layers).eval()
fx_model = self._create_feature_extractor(model, return_layers)
# Check that we have same parameters
for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(),
fx_model.named_parameters()):
assert n1 == n2
assert p1.equal(p2)
# And that ouputs match
with torch.no_grad():
ilg_out = ilg_model(self.inp)
fgn_out = fx_model(self.inp)
assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys()))
for k in ilg_out.keys():
assert ilg_out[k].equal(fgn_out[k])
@pytest.mark.parametrize('model_name', get_available_models())
def test_jit_forward_backward(self, model_name):
set_rng_seed(0)
model = models.__dict__[model_name](**self.model_defaults).train()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes)
model = torch.jit.script(model)
fgn_out = model(self.inp)
sum([o.mean() for o in fgn_out.values()]).backward()
def test_train_eval(self):
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.dropout = torch.nn.Dropout(p=1.)
def forward(self, x):
x = x.mean()
x = self.dropout(x) # dropout
if self.training:
x += 100 # add
else:
x *= 0 # mul
x -= 0 # sub
return x
model = TestModel()
train_return_nodes = ['dropout', 'add', 'sub']
eval_return_nodes = ['dropout', 'mul', 'sub']
def checks(model, mode):
with torch.no_grad():
out = model(torch.ones(10, 10))
if mode == 'train':
# Check that dropout is respected
assert out['dropout'].item() == 0
# Check that control flow dependent on training_mode is respected
assert out['sub'].item() == 100
assert 'add' in out
assert 'mul' not in out
elif mode == 'eval':
# Check that dropout is respected
assert out['dropout'].item() == 1
# Check that control flow dependent on training_mode is respected
assert out['sub'].item() == 0
assert 'mul' in out
assert 'add' not in out
# Starting from train mode
model.train()
fx_model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes)
# Check that the models stay in their original training state
assert model.training
assert fx_model.training
# Check outputs
checks(fx_model, 'train')
# Check outputs after switching to eval mode
fx_model.eval()
checks(fx_model, 'eval')
# Starting from eval mode
model.eval()
fx_model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes)
# Check that the models stay in their original training state
assert not model.training
assert not fx_model.training
# Check outputs
checks(fx_model, 'eval')
# Check outputs after switching to train mode
fx_model.train()
checks(fx_model, 'train')
def test_leaf_module_and_function(self):
class LeafModule(torch.nn.Module):
def forward(self, x):
# This would raise a TypeError if it were not in a leaf module
int(x.shape[0])
return torch.nn.functional.relu(x + 4)
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 1, 3)
self.leaf_module = LeafModule()
def forward(self, x):
leaf_function(x.shape[0])
x = self.conv(x)
return self.leaf_module(x)
model = self._create_feature_extractor(
TestModule(), return_nodes=['leaf_module'],
tracer_kwargs={'leaf_modules': [LeafModule],
'autowrap_functions': [leaf_function]}).train()
# Check that LeafModule is not in the list of nodes
assert 'relu' not in [str(n) for n in model.graph.nodes]
assert 'leaf_module' in [str(n) for n in model.graph.nodes]
# Check forward
out = model(self.inp)
# And backward
out['leaf_module'].mean().backward()
......@@ -13,3 +13,4 @@ from . import segmentation
from . import detection
from . import video
from . import quantization
from . import feature_extraction
from typing import Dict, Callable, List, Union, Optional, Tuple
from collections import OrderedDict
import warnings
import re
from copy import deepcopy
from itertools import chain
import torch
from torch import nn
from torch import fx
from torch.fx.graph_module import _copy_attr
__all__ = ['create_feature_extractor', 'get_graph_node_names']
class LeafModuleAwareTracer(fx.Tracer):
"""
An fx.Tracer that allows the user to specify a set of leaf modules, ie.
modules that are not to be traced through. The resulting graph ends up
having single nodes referencing calls to the leaf modules' forward methods.
"""
def __init__(self, *args, **kwargs):
self.leaf_modules = {}
if 'leaf_modules' in kwargs:
leaf_modules = kwargs.pop('leaf_modules')
self.leaf_modules = leaf_modules
super(LeafModuleAwareTracer, self).__init__(*args, **kwargs)
def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool:
if isinstance(m, tuple(self.leaf_modules)):
return True
return super().is_leaf_module(m, module_qualname)
class NodePathTracer(LeafModuleAwareTracer):
"""
NodePathTracer is an FX tracer that, for each operation, also records the
name of the Node from which the operation originated. A node name here is
a `.` seperated path walking the hierarchy from top level module down to
leaf operation or leaf module. The name of the top level module is not
included as part of the node name. For example, if we trace a module whose
forward method applies a ReLU module, the name for that node will simply
be 'relu'.
Some notes on the specifics:
- Nodes are recorded to `self.node_to_qualname` which is a dictionary
mapping a given Node object to its node name.
- Nodes are recorded in the order which they are executed during
tracing.
- When a duplicate node name is encountered, a suffix of the form
_{int} is added. The counter starts from 1.
"""
def __init__(self, *args, **kwargs):
super(NodePathTracer, self).__init__(*args, **kwargs)
# Track the qualified name of the Node being traced
self.current_module_qualname = ''
# A map from FX Node to the qualified name\#
# NOTE: This is loosely like the "qualified name" mentioned in the
# torch.fx docs https://pytorch.org/docs/stable/fx.html but adapted
# for the purposes of the torchvision feature extractor
self.node_to_qualname = OrderedDict()
def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs):
"""
Override of `fx.Tracer.call_module`
This override:
1) Stores away the qualified name of the caller for restoration later
2) Adds the qualified name of the caller to
`current_module_qualname` for retrieval by `create_proxy`
3) Once a leaf module is reached, calls `create_proxy`
4) Restores the caller's qualified name into current_module_qualname
"""
old_qualname = self.current_module_qualname
try:
module_qualname = self.path_of_module(m)
self.current_module_qualname = module_qualname
if not self.is_leaf_module(m, module_qualname):
out = forward(*args, **kwargs)
return out
return self.create_proxy('call_module', module_qualname, args, kwargs)
finally:
self.current_module_qualname = old_qualname
def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs,
name=None, type_expr=None, *_) -> fx.proxy.Proxy:
"""
Override of `Tracer.create_proxy`. This override intercepts the recording
of every operation and stores away the current traced module's qualified
name in `node_to_qualname`
"""
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
self.node_to_qualname[proxy.node] = self._get_node_qualname(
self.current_module_qualname, proxy.node)
return proxy
def _get_node_qualname(
self, module_qualname: str, node: fx.node.Node) -> str:
node_qualname = module_qualname
if node.op == 'call_module':
# Node terminates in a leaf module so the module_qualname is a
# complete description of the node
for existing_qualname in reversed(self.node_to_qualname.values()):
# Check to see if existing_qualname is of the form
# {node_qualname} or {node_qualname}_{int}
if re.match(rf'{node_qualname}(_[0-9]+)?$',
existing_qualname) is not None:
postfix = existing_qualname.replace(node_qualname, '')
if len(postfix):
# Existing_qualname is of the form {node_qualname}_{int}
next_index = int(postfix[1:]) + 1
else:
# existing_qualname is of the form {node_qualname}
next_index = 1
node_qualname += f'_{next_index}'
break
pass
else:
# Node terminates in non- leaf module so the node name needs to be
# appended
if len(node_qualname) > 0:
# Only append '.' if we are deeper than the top level module
node_qualname += '.'
node_qualname += str(node)
return node_qualname
def _is_subseq(x, y):
"""Check if y is a subseqence of x
https://stackoverflow.com/a/24017747/4391249
"""
iter_x = iter(x)
return all(any(x_item == y_item for x_item in iter_x) for y_item in y)
def _warn_graph_differences(
train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
"""
Utility function for warning the user if there are differences between
the train graph nodes and the eval graph nodes.
"""
train_nodes = list(train_tracer.node_to_qualname.values())
eval_nodes = list(eval_tracer.node_to_qualname.values())
if len(train_nodes) == len(eval_nodes) and all(
t == e for t, e in zip(train_nodes, eval_nodes)):
return
suggestion_msg = (
"When choosing nodes for feature extraction, you may need to specify "
"output nodes for train and eval mode separately.")
if _is_subseq(train_nodes, eval_nodes):
msg = ("NOTE: The nodes obtained by tracing the model in eval mode "
"are a subsequence of those obtained in train mode. ")
elif _is_subseq(eval_nodes, train_nodes):
msg = ("NOTE: The nodes obtained by tracing the model in train mode "
"are a subsequence of those obtained in eval mode. ")
else:
msg = ("The nodes obtained by tracing the model in train mode "
"are different to those obtained in eval mode. ")
warnings.warn(msg + suggestion_msg)
def get_graph_node_names(
model: nn.Module, tracer_kwargs: Dict = {},
suppress_diff_warning: bool = False) -> Tuple[List[str], List[str]]:
"""
Dev utility to return node names in order of execution. See note on node
names under :func:`create_feature_extractor`. Useful for seeing which node
names are available for feature extraction. There are two reasons that
node names can't easily be read directly from the code for a model:
1. Not all submodules are traced through. Modules from `torch.nn` all
fall within this category.
2. Nodes representing the repeated application of the same operation
or leaf module get a `_{counter}` postfix.
The model is traced twice: once in train mode, and once in eval mode. Both
sets of nodes are returned.
Args:
model (nn.Module): model for which we'd like to print node names
tracer_kwargs (dict, optional): a dictionary of keywork arguments for
`NodePathTracer` (they are eventually passed onto
`torch.fx.Tracer`).
suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
Returns:
tuple(list, list): a list of node names from tracing the model in
train mode, and another from tracing the model in eval mode.
Examples::
>>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model)
"""
is_training = model.training
train_tracer = NodePathTracer(**tracer_kwargs)
train_tracer.trace(model.train())
eval_tracer = NodePathTracer(**tracer_kwargs)
eval_tracer.trace(model.eval())
train_nodes = list(train_tracer.node_to_qualname.values())
eval_nodes = list(eval_tracer.node_to_qualname.values())
if not suppress_diff_warning:
_warn_graph_differences(train_tracer, eval_tracer)
# Restore training state
model.train(is_training)
return train_nodes, eval_nodes
class DualGraphModule(fx.GraphModule):
"""
A derivative of `fx.GraphModule`. Differs in the following ways:
- Requires a train and eval version of the underlying graph
- Copies submodules according to the nodes of both train and eval graphs.
- Calling train(mode) switches between train graph and eval graph.
"""
def __init__(self,
root: torch.nn.Module,
train_graph: fx.Graph,
eval_graph: fx.Graph,
class_name: str = 'GraphModule'):
"""
Args:
root (nn.Module): module from which the copied module hierarchy is
built
train_graph (fx.Graph): the graph that should be used in train mode
eval_graph (fx.Graph): the graph that should be used in eval mode
"""
super(fx.GraphModule, self).__init__()
self.__class__.__name__ = class_name
self.train_graph = train_graph
self.eval_graph = eval_graph
# Copy all get_attr and call_module ops (indicated by BOTH train and
# eval graphs)
for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
if node.op in ['get_attr', 'call_module']:
assert isinstance(node.target, str)
_copy_attr(root, self, node.target)
# train mode by default
self.train()
self.graph = train_graph
# (borrowed from fx.GraphModule):
# Store the Tracer class responsible for creating a Graph separately as part of the
# GraphModule state, except when the Tracer is defined in a local namespace.
# Locally defined Tracers are not pickleable. This is needed because torch.package will
# serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
# to re-create the Graph during deserialization.
assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \
"Train mode and eval mode should use the same tracer class"
self._tracer_cls = None
if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
self._tracer_cls = self.graph._tracer_cls
def train(self, mode=True):
"""
Swap out the graph depending on the selected training mode.
NOTE this should be safe when calling model.eval() because that just
calls this with mode == False.
"""
# NOTE: Only set self.graph if the current graph is not the desired
# one. This saves us from recompiling the graph where not necessary.
if mode and not self.training:
self.graph = self.train_graph
elif not mode and self.training:
self.graph = self.eval_graph
return super().train(mode=mode)
def create_feature_extractor(
model: nn.Module,
return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
tracer_kwargs: Dict = {},
suppress_diff_warning: bool = False) -> fx.GraphModule:
"""
Creates a new graph module that returns intermediate nodes from a given
model as dictionary with user specified keys as strings, and the requested
outputs as values. This is achieved by re-writing the computation graph of
the model via FX to return the desired nodes as outputs. All unused nodes
are removed, together with their corresponding parameters.
A note on node specification: For the purposes of this feature extraction
utility, a node name is specified as a `.` seperated path walking the
hierarchy from top level module down to leaf operation or leaf module. For
instance `blocks.5.3.bn1`. The keys of the `return_nodes` argument should
point to either a node's name, or some truncated version of it. For
example, one could provide `blocks.5` as a key, and the last node with
that prefix will be selected. :func:`get_graph_node_names` is a useful
helper function for getting a list of node names of a model.
Not all models will be FX traceable, although with some massaging they can
be made to cooperate. Here's a (not exhaustive) list of tips:
- If you don't need to trace through a particular, problematic
sub-module, turn it into a "leaf module" by passing a list of
`leaf_modules` as one of the `tracer_kwargs` (see example below). It
will not be traced through, but rather, the resulting graph will
hold a reference to that module's forward method.
- Likewise, you may turn functions into leaf functions by passing a
list of `autowrap_functions` as one of the `tracer_kwargs` (see
example below).
- Some inbuilt Python functions can be problematic. For instance,
`int` will raise an error during tracing. You may wrap them in your
own function and then pass that in `autowrap_functions` as one of
the `tracer_kwargs`.
For further information on FX see the
`torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_.
Args:
model (nn.Module): model on which we will extract the features
return_nodes (list or dict, optional): either a `List` or a `Dict`
containing the names (or partial names - see note above)
of the nodes for which the activations will be returned. If it is
a `Dict`, the keys are the node names, and the values
are the user-specified keys for the graph module's returned
dictionary. If it is a `List`, it is treated as a `Dict` mapping
node specification strings directly to output names. In the case
that `train_return_nodes` and `eval_return_nodes` are specified,
this should not be specified.
train_return_nodes (list or dict, optional): similar to
`return_nodes`. This can be used if the return nodes
for train mode are different than those from eval mode.
If this is specified, `eval_return_nodes` must also be specified,
and `return_nodes` should not be specified.
eval_return_nodes (list or dict, optional): similar to
`return_nodes`. This can be used if the return nodes
for train mode are different than those from eval mode.
If this is specified, `train_return_nodes` must also be specified,
and `return_nodes` should not be specified.
tracer_kwargs (dict, optional): a dictionary of keywork arguments for
`NodePathTracer` (which passes them onto it's parent class
`torch.fx.Tracer`).
suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
Examples::
>>> # Feature extraction with resnet
>>> model = torchvision.models.resnet18()
>>> # extract layer1 and layer3, giving as names `feat1` and feat2`
>>> model = create_feature_extractor(
>>> model, {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = model(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
>>> [('feat1', torch.Size([1, 64, 56, 56])),
>>> ('feat2', torch.Size([1, 256, 14, 14]))]
>>> # Specifying leaf modules and leaf functions
>>> def leaf_function(x):
>>> # This would raise a TypeError if traced through
>>> return int(x)
>>>
>>> class LeafModule(torch.nn.Module):
>>> def forward(self, x):
>>> # This would raise a TypeError if traced through
>>> int(x.shape[0])
>>> return torch.nn.functional.relu(x + 4)
>>>
>>> class MyModule(torch.nn.Module):
>>> def __init__(self):
>>> super().__init__()
>>> self.conv = torch.nn.Conv2d(3, 1, 3)
>>> self.leaf_module = LeafModule()
>>>
>>> def forward(self, x):
>>> leaf_function(x.shape[0])
>>> x = self.conv(x)
>>> return self.leaf_module(x)
>>>
>>> model = create_feature_extractor(
>>> MyModule(), return_nodes=['leaf_module'],
>>> tracer_kwargs={'leaf_modules': [LeafModule],
>>> 'autowrap_functions': [leaf_function]})
"""
is_training = model.training
assert any(arg is not None for arg in [
return_nodes, train_return_nodes, eval_return_nodes]), (
"Either `return_nodes` or `train_return_nodes` and "
"`eval_return_nodes` together, should be specified")
assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \
("If any of `train_return_nodes` and `eval_return_nodes` are "
"specified, then both should be specified")
assert ((return_nodes is None) ^ (train_return_nodes is None)), \
("If `train_return_nodes` and `eval_return_nodes` are specified, "
"then both should be specified")
# Put *_return_nodes into Dict[str, str] format
def to_strdict(n) -> Dict[str, str]:
if isinstance(n, list):
return {str(i): str(i) for i in n}
return {str(k): str(v) for k, v in n.items()}
if train_return_nodes is None:
return_nodes = to_strdict(return_nodes)
train_return_nodes = deepcopy(return_nodes)
eval_return_nodes = deepcopy(return_nodes)
else:
train_return_nodes = to_strdict(train_return_nodes)
eval_return_nodes = to_strdict(eval_return_nodes)
# Repeat the tracing and graph rewriting for train and eval mode
tracers = {}
graphs = {}
mode_return_nodes: Dict[str, Dict[str, str]] = {
'train': train_return_nodes,
'eval': eval_return_nodes
}
for mode in ['train', 'eval']:
if mode == 'train':
model.train()
elif mode == 'eval':
model.eval()
# Instantiate our NodePathTracer and use that to trace the model
tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model)
name = model.__class__.__name__ if isinstance(
model, nn.Module) else model.__name__
graph_module = fx.GraphModule(tracer.root, graph, name)
available_nodes = list(tracer.node_to_qualname.values())
# FIXME We don't know if we should expect this to happen
assert len(set(available_nodes)) == len(available_nodes), \
"There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
# Check that all outputs in return_nodes are present in the model
for query in mode_return_nodes[mode].keys():
# To check if a query is available we need to check that at least
# one of the available names starts with it up to a .
if not any([re.match(rf'^{query}(\.|$)', n) is not None
for n in available_nodes]):
raise ValueError(
f"node: '{query}' is not present in model. Hint: use "
"`get_graph_node_names` to make sure the "
"`return_nodes` you specified are present. It may even "
"be that you need to specify `train_return_nodes` and "
"`eval_return_nodes` separately.")
# Remove existing output nodes (train mode)
orig_output_nodes = []
for n in reversed(graph_module.graph.nodes):
if n.op == 'output':
orig_output_nodes.append(n)
assert len(orig_output_nodes)
for n in orig_output_nodes:
graph_module.graph.erase_node(n)
# Find nodes corresponding to return_nodes and make them into output_nodes
nodes = [n for n in graph_module.graph.nodes]
output_nodes = OrderedDict()
for n in reversed(nodes):
module_qualname = tracer.node_to_qualname.get(n)
if module_qualname is None:
# NOTE - Know cases where this happens:
# - Node representing creation of a tensor constant - probably
# not interesting as a return node
# - When packing outputs into a named tuple like in InceptionV3
continue
for query in mode_return_nodes[mode]:
depth = query.count('.')
if '.'.join(module_qualname.split('.')[:depth + 1]) == query:
output_nodes[mode_return_nodes[mode][query]] = n
mode_return_nodes[mode].pop(query)
break
output_nodes = OrderedDict(reversed(list(output_nodes.items())))
# And add them in the end of the graph
with graph_module.graph.inserting_after(nodes[-1]):
graph_module.graph.output(output_nodes)
# Remove unused modules / parameters
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
# Keep track of the tracer and graph so we can choose the main one
tracers[mode] = tracer
graphs[mode] = graph
# Warn user if there are any discrepancies between the graphs of the
# train and eval modes
if not suppress_diff_warning:
_warn_graph_differences(tracers['train'], tracers['eval'])
# Build the final graph module
graph_module = DualGraphModule(
model, graphs['train'], graphs['eval'], class_name=name)
# Restore original training mode
model.train(is_training)
graph_module.train(is_training)
return graph_module
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment