Unverified Commit 903ea4a6 authored by Alexander Soare, committed by GitHub

Improve FX node naming (#4418)



* draft commit

* Polish and add corresponding test

* Update docs

* Update torchvision/models/feature_extraction.py

* Update docs/source/feature_extraction.rst
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 6518372e
@@ -19,8 +19,8 @@ It works by following roughly these steps:
1. Symbolically tracing the model to get a graphical representation of
how it transforms the input, step by step.
2. Setting the user-selected graph nodes as ouputs.
3. Removing all redundant nodes (anything downstream of the ouput nodes).
2. Setting the user-selected graph nodes as outputs.
3. Removing all redundant nodes (anything downstream of the output nodes).
4. Generating python code from the resulting graph and bundling that into a
PyTorch module together with the graph itself.
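For intuition, here is a minimal sketch of step 1, tracing a torchvision model directly with ``torch.fx`` (the raw ``torch.fx`` node names printed here are the ones the naming convention described below makes friendlier):

.. code-block:: python

    import torch.fx
    import torchvision

    model = torchvision.models.resnet18()
    # Step 1: symbolically trace the model into a graph of operations.
    graph_module = torch.fx.symbolic_trace(model)
    # Each node records its operation kind and an automatically generated name.
    for node in graph_module.graph.nodes:
        print(node.op, node.name)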
@@ -30,6 +30,39 @@ The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.
.. _about-node-names:
**About Node Names**
In order to specify which nodes should be output nodes for extracted
features, one should be familiar with the node naming convention used here
(which differs slightly from that used in ``torch.fx``). A node name is
specified as a ``.`` separated path walking the module hierarchy from top level
module down to leaf operation or leaf module. For instance ``"layer4.2.relu"``
in ResNet-50 represents the output of the ReLU of the 2nd block of the 4th
layer of the ``ResNet`` module. Here are some finer points to keep in mind:
- When specifying node names for :func:`create_feature_extractor`, you may
provide a truncated version of a node name as a shortcut. To see how this
works, try creating a ResNet-50 model and printing the node names with
``train_nodes, _ = get_graph_node_names(model)`` followed by ``print(train_nodes)``, and
observe that the last node pertaining to ``layer4`` is
``"layer4.2.relu_2"``. One may specify ``"layer4.2.relu_2"`` as the return
node, or just ``"layer4"``, as this, by convention, refers to the last node
(in order of execution) of ``layer4``. Both options are shown in the sketch
just after this list.
- If a certain module or operation is repeated more than once, node names get
an additional ``_{int}`` postfix to disambiguate. For instance, maybe the
addition (``+``) operation is used three times in the same ``forward``
method. Then there would be ``"path.to.module.add"``,
``"path.to.module.add_1"``, ``"path.to.module.add_2"``. The counter is
maintained within the scope of the direct parent. So in ResNet-50 there is
a ``"layer4.1.add"`` and a ``"layer4.2.add"``. Because the addition
operations reside in different blocks, there is no need for a postfix to
disambiguate.
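As a small, self-contained sketch of the conventions above (ResNet-50 is used purely for illustration, and the exact node names may vary between torchvision versions):

.. code-block:: python

    import torch
    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import (
        create_feature_extractor, get_graph_node_names)

    model = resnet50()
    train_nodes, eval_nodes = get_graph_node_names(model)
    print(train_nodes)  # the last node under layer4 should be 'layer4.2.relu_2'

    # Either the full node name or the truncated shortcut may be requested;
    # 'layer4' resolves to the last node (in execution order) under layer4.
    extractor = create_feature_extractor(model, return_nodes=['layer4'])
    out = extractor(torch.rand(1, 3, 224, 224))
    print(list(out.keys()))  # ['layer4']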
**An Example**
Here is an example of how we might extract features for MaskRCNN:
.. code-block:: python
@@ -80,10 +113,10 @@ Here is an example of how we might extract features for MaskRCNN:
# Now you can build the feature extractor. This returns a module whose forward
# method returns a dictionary like:
# {
# 'layer1': ouput of layer 1,
# 'layer2': ouput of layer 2,
# 'layer3': ouput of layer 3,
# 'layer4': ouput of layer 4,
# 'layer1': output of layer 1,
# 'layer2': output of layer 2,
# 'layer3': output of layer 3,
# 'layer4': output of layer 4,
# }
create_feature_extractor(m, return_nodes=return_nodes)
@@ -33,6 +33,41 @@ def leaf_function(x):
return int(x)
# Needed by TestFxFeatureExtraction. Checking that node naming conventions
# are respected. Particularly the index postfix of repeated node names
class TestSubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
x = x + 1
x = x + 1
x = self.relu(x)
x = self.relu(x)
return x
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.submodule = TestSubModule()
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.submodule(x)
x = x + 1
x = x + 1
x = self.relu(x)
x = self.relu(x)
return x
test_module_nodes = [
'x', 'submodule.add', 'submodule.add_1', 'submodule.relu',
'submodule.relu_1', 'add', 'add_1', 'relu', 'relu_1']
class TestFxFeatureExtraction:
inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device='cpu')
model_defaults = {
@@ -104,6 +139,11 @@ class TestFxFeatureExtraction:
else: # otherwise skip this check
raise ValueError
def test_node_name_conventions(self):
model = TestModule()
train_nodes, _ = get_graph_node_names(model)
assert all(a == b for a, b in zip(train_nodes, test_module_nodes))
@pytest.mark.parametrize('model_name', get_available_models())
def test_forward_backward(self, model_name):
model = models.__dict__[model_name](**self.model_defaults).train()
@@ -97,9 +97,24 @@ class NodePathTracer(LeafModuleAwareTracer):
def _get_node_qualname(
self, module_qualname: str, node: fx.node.Node) -> str:
node_qualname = module_qualname
if node.op == 'call_module':
# Node terminates in a leaf module so the module_qualname is a
# complete description of the node
if node.op != 'call_module':
# In this case module_qualname from torch.fx doesn't go all the
# way to the leaf function/op so we need to append it
if len(node_qualname) > 0:
# Only append '.' if we are deeper than the top level module
node_qualname += '.'
node_qualname += str(node)
# Now we need to add an _{index} postfix on any repeated node names
# For modules we do this from scratch
# But for anything else, torch.fx already has a globally scoped
# _{index} postfix. But we want it locally (relative to direct parent)
# scoped. So first we need to undo the torch.fx postfix
if re.match(r'.+_[0-9]+$', node_qualname) is not None:
node_qualname = node_qualname.rsplit('_', 1)[0]
# ... and now we add on our own postfix
for existing_qualname in reversed(self.node_to_qualname.values()):
# Check to see if existing_qualname is of the form
# {node_qualname} or {node_qualname}_{int}
@@ -107,21 +122,14 @@ class NodePathTracer(LeafModuleAwareTracer):
existing_qualname) is not None:
postfix = existing_qualname.replace(node_qualname, '')
if len(postfix):
# Existing_qualname is of the form {node_qualname}_{int}
# existing_qualname is of the form {node_qualname}_{int}
next_index = int(postfix[1:]) + 1
else:
# existing_qualname is of the form {node_qualname}
next_index = 1
node_qualname += f'_{next_index}'
break
pass
else:
# Node terminates in non- leaf module so the node name needs to be
# appended
if len(node_qualname) > 0:
# Only append '.' if we are deeper than the top level module
node_qualname += '.'
node_qualname += str(node)
return node_qualname
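Roughly, the postfix handling above re-counts repeated names within the scope of the direct parent rather than globally. A standalone sketch of that idea (a simplification for illustration only; the library code instead scans ``self.node_to_qualname``):

import re

def locally_scoped(qualname, counts):
    # Drop any globally scoped _{int} postfix that torch.fx already appended...
    if re.match(r'.+_[0-9]+$', qualname):
        qualname = qualname.rsplit('_', 1)[0]
    # ...then append a postfix counted per parent-qualified name.
    n = counts.get(qualname, 0)
    counts[qualname] = n + 1
    return qualname if n == 0 else f'{qualname}_{n}'

counts = {}
fx_names = ['layer1.0.add', 'layer1.0.add_1', 'layer1.1.add_2', 'layer1.1.add_3']
print([locally_scoped(name, counts) for name in fx_names])
# -> ['layer1.0.add', 'layer1.0.add_1', 'layer1.1.add', 'layer1.1.add_1']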
@@ -171,19 +179,23 @@ def get_graph_node_names(
names are available for feature extraction. There are two reasons that
node names can't easily be read directly from the code for a model:
1. Not all submodules are traced through. Modules from `torch.nn` all
1. Not all submodules are traced through. Modules from ``torch.nn`` all
fall within this category.
2. Nodes representing the repeated application of the same operation
or leaf module get a `_{counter}` postfix.
or leaf module get a ``_{counter}`` postfix.
The model is traced twice: once in train mode, and once in eval mode. Both
sets of nodes are returned.
sets of node names are returned.
For more details on the node naming conventions used here, please see the
:ref:`relevant subheading <about-node-names>` in the
`documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.
Args:
model (nn.Module): model for which we'd like to print node names
tracer_kwargs (dict, optional): a dictionary of keyword arguments for
`NodePathTracer` (they are eventually passed onto
`torch.fx.Tracer`).
``NodePathTracer`` (they are eventually passed onto
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
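A quick usage sketch (the printed names are indicative for ResNet-50 and may differ between torchvision versions):

import torchvision
from torchvision.models.feature_extraction import get_graph_node_names

model = torchvision.models.resnet50()
train_nodes, eval_nodes = get_graph_node_names(model)
# For most classification models the two lists coincide; they can differ
# when the forward pass branches on self.training.
print(train_nodes == eval_nodes)
print(train_nodes[-3:])  # e.g. ['avgpool', 'flatten', 'fc']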
@@ -289,58 +301,55 @@ def create_feature_extractor(
the model via FX to return the desired nodes as outputs. All unused nodes
are removed, together with their corresponding parameters.
A note on node specification: For the purposes of this feature extraction
utility, a node name is specified as a `.` seperated path walking the
hierarchy from top level module down to leaf operation or leaf module. For
instance `blocks.5.3.bn1`. The keys of the `return_nodes` argument should
point to either a node's name, or some truncated version of it. For
example, one could provide `blocks.5` as a key, and the last node with
that prefix will be selected. :func:`get_graph_node_names` is a useful
helper function for getting a list of node names of a model.
Desired output nodes must be specified as a ``.`` separated
path walking the module hierarchy from top level module down to leaf
operation or leaf module. For more details on the node naming conventions
used here, please see the :ref:`relevant subheading <about-node-names>`
in the `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.
Not all models will be FX traceable, although with some massaging they can
be made to cooperate. Here's a (not exhaustive) list of tips:
- If you don't need to trace through a particular, problematic
sub-module, turn it into a "leaf module" by passing a list of
`leaf_modules` as one of the `tracer_kwargs` (see example below). It
will not be traced through, but rather, the resulting graph will
``leaf_modules`` as one of the ``tracer_kwargs`` (see example below).
It will not be traced through, but rather, the resulting graph will
hold a reference to that module's forward method.
- Likewise, you may turn functions into leaf functions by passing a
list of `autowrap_functions` as one of the `tracer_kwargs` (see
list of ``autowrap_functions`` as one of the ``tracer_kwargs`` (see
example below).
- Some inbuilt Python functions can be problematic. For instance,
`int` will raise an error during tracing. You may wrap them in your
own function and then pass that in `autowrap_functions` as one of
the `tracer_kwargs`.
``int`` will raise an error during tracing. You may wrap them in your
own function and then pass that in ``autowrap_functions`` as one of
the ``tracer_kwargs``.
For further information on FX see the
`torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_.
Args:
model (nn.Module): model on which we will extract the features
return_nodes (list or dict, optional): either a `List` or a `Dict`
return_nodes (list or dict, optional): either a ``List`` or a ``Dict``
containing the names (or partial names - see note above)
of the nodes for which the activations will be returned. If it is
a `Dict`, the keys are the node names, and the values
a ``Dict``, the keys are the node names, and the values
are the user-specified keys for the graph module's returned
dictionary. If it is a `List`, it is treated as a `Dict` mapping
dictionary. If it is a ``List``, it is treated as a ``Dict`` mapping
node specification strings directly to output names. In the case
that `train_return_nodes` and `eval_return_nodes` are specified,
that ``train_return_nodes`` and ``eval_return_nodes`` are specified,
this should not be specified.
train_return_nodes (list or dict, optional): similar to
`return_nodes`. This can be used if the return nodes
``return_nodes``. This can be used if the return nodes
for train mode are different than those from eval mode.
If this is specified, `eval_return_nodes` must also be specified,
and `return_nodes` should not be specified.
If this is specified, ``eval_return_nodes`` must also be specified,
and ``return_nodes`` should not be specified.
eval_return_nodes (list or dict, optional): similar to
`return_nodes`. This can be used if the return nodes
``return_nodes``. This can be used if the return nodes
for train mode are different than those from eval mode.
If this is specified, `train_return_nodes` must also be specified,
If this is specified, ``train_return_nodes`` must also be specified,
and ``return_nodes`` should not be specified.
tracer_kwargs (dict, optional): a dictionary of keyword arguments for
`NodePathTracer` (which passes them onto it's parent class
`torch.fx.Tracer`).
``NodePathTracer`` (which passes them on to its parent class
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
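To tie the ``tracer_kwargs`` options above together, here is a hedged, self-contained sketch; the module, function, and node names are made up purely for illustration:

import torch
from torchvision.models.feature_extraction import create_feature_extractor

def wrapped_int(x):
    # Wrapping a problematic builtin so it can be treated as a leaf function.
    return int(x)

class Inner(torch.nn.Module):
    # A submodule we do not want to trace through.
    def forward(self, x):
        return x.relu()

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.inner = Inner()

    def forward(self, x):
        return self.inner(self.conv(x))

extractor = create_feature_extractor(
    Model(),
    return_nodes={'conv': 'feat'},  # dict form: node name -> output key
    tracer_kwargs={'leaf_modules': [Inner],
                   'autowrap_functions': [wrapped_int]})
out = extractor(torch.rand(1, 3, 32, 32))
print(out['feat'].shape)  # torch.Size([1, 8, 30, 30])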