Unverified commit 903ea4a6 authored by Alexander Soare, committed by GitHub

Improve FX node naming (#4418)



* draft commit

* Polish and add corresponding test

* Update docs

* Update torchvision/models/feature_extraction.py

* Update docs/source/feature_extraction.rst
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 6518372e
@@ -19,8 +19,8 @@ It works by following roughly these steps:

1. Symbolically tracing the model to get a graphical representation of
   how it transforms the input, step by step.
2. Setting the user-selected graph nodes as outputs.
3. Removing all redundant nodes (anything downstream of the output nodes).
4. Generating python code from the resulting graph and bundling that into a
   PyTorch module together with the graph itself.

@@ -30,6 +30,39 @@ The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.
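As a quick illustration of step 1, one can symbolically trace a model with
plain ``torch.fx`` (a minimal sketch; the node names printed here follow
``torch.fx``'s own convention rather than the one described below):

.. code-block:: python

    import torch.fx
    from torchvision.models import resnet18

    model = resnet18()
    graph_module = torch.fx.symbolic_trace(model)
    # each graph node is one step of the computation
    # (placeholder, call_module, call_function, ...)
    for node in list(graph_module.graph.nodes)[:5]:
        print(node.op, node.name)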
.. _about-node-names:
**About Node Names**
In order to specify which nodes should be output nodes for extracted
features, one should be familiar with the node naming convention used here
(which differs slightly from that used in ``torch.fx``). A node name is
specified as a ``.`` separated path walking the module hierarchy from top level
module down to leaf operation or leaf module. For instance ``"layer4.2.relu"``
in ResNet-50 represents the output of the ReLU of the 2nd block of the 4th
layer of the ``ResNet`` module. Here are some finer points to keep in mind:
- When specifying node names for :func:`create_feature_extractor`, you may
  provide a truncated version of a node name as a shortcut. To see how this
  works, try creating a ResNet-50 model and printing the node names with
  ``train_nodes, _ = get_graph_node_names(model)`` followed by
  ``print(train_nodes)``, and observe that the last node pertaining to
  ``layer4`` is ``"layer4.2.relu_2"``. One may specify ``"layer4.2.relu_2"``
  as the return node, or just ``"layer4"``, as this, by convention, refers
  to the last node (in order of execution) of ``layer4`` (see the sketch
  after this list).
- If a certain module or operation is repeated more than once, node names get
an additional ``_{int}`` postfix to disambiguate. For instance, maybe the
addition (``+``) operation is used three times in the same ``forward``
method. Then there would be ``"path.to.module.add"``,
``"path.to.module.add_1"``, ``"path.to.module.add_2"``. The counter is
maintained within the scope of the direct parent. So in ResNet-50 there is
a ``"layer4.1.add"`` and a ``"layer4.2.add"``. Because the addition
operations reside in different blocks, there is no need for a postfix to
disambiguate.
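Both points can be seen in a short sketch (the output shape shown assumes a
standard ResNet-50 with a 224x224 input):

.. code-block:: python

    import torch
    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import (
        create_feature_extractor, get_graph_node_names)

    model = resnet50()
    train_nodes, _ = get_graph_node_names(model)
    print([n for n in train_nodes if n.startswith('layer4.2')])
    # the last entry is 'layer4.2.relu_2'

    # 'layer4' is accepted as a truncated name for 'layer4.2.relu_2'
    extractor = create_feature_extractor(model, return_nodes=['layer4'])
    out = extractor(torch.rand(1, 3, 224, 224))
    print(out['layer4'].shape)  # expected: torch.Size([1, 2048, 7, 7])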
**An Example**
Here is an example of how we might extract features for MaskRCNN:

.. code-block:: python

@@ -80,10 +113,10 @@ Here is an example of how we might extract features for MaskRCNN:

    # Now you can build the feature extractor. This returns a module whose forward
    # method returns a dictionary like:
    # {
    #     'layer1': output of layer 1,
    #     'layer2': output of layer 2,
    #     'layer3': output of layer 3,
    #     'layer4': output of layer 4,
    # }
    create_feature_extractor(m, return_nodes=return_nodes)
...
@@ -33,6 +33,41 @@ def leaf_function(x):
    return int(x)


# Needed by TestFXFeatureExtraction. Checking that node naming conventions
# are respected. Particularly the index postfix of repeated node names
class TestSubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submodule = TestSubModule()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.submodule(x)
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


test_module_nodes = [
    'x', 'submodule.add', 'submodule.add_1', 'submodule.relu',
    'submodule.relu_1', 'add', 'add_1', 'relu', 'relu_1']
class TestFxFeatureExtraction:
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device='cpu')
    model_defaults = {

@@ -104,6 +139,11 @@ class TestFxFeatureExtraction:
            else:  # otherwise skip this check
                raise ValueError

    def test_node_name_conventions(self):
        model = TestModule()
        train_nodes, _ = get_graph_node_names(model)
        assert all(a == b for a, b in zip(train_nodes, test_module_nodes))

    @pytest.mark.parametrize('model_name', get_available_models())
    def test_forward_backward(self, model_name):
        model = models.__dict__[model_name](**self.model_defaults).train()
...
@@ -97,9 +97,24 @@ class NodePathTracer(LeafModuleAwareTracer):
    def _get_node_qualname(
            self, module_qualname: str, node: fx.node.Node) -> str:
        node_qualname = module_qualname

        if node.op != 'call_module':
            # In this case module_qualname from torch.fx doesn't go all the
            # way to the leaf function/op so we need to append it
            if len(node_qualname) > 0:
                # Only append '.' if we are deeper than the top level module
                node_qualname += '.'
            node_qualname += str(node)

        # Now we need to add an _{index} postfix on any repeated node names
        # For modules we do this from scratch
        # But for anything else, torch.fx already has a globally scoped
        # _{index} postfix. But we want it locally (relative to direct parent)
        # scoped. So first we need to undo the torch.fx postfix
        if re.match(r'.+_[0-9]+$', node_qualname) is not None:
            node_qualname = node_qualname.rsplit('_', 1)[0]

        # ... and now we add on our own postfix
        for existing_qualname in reversed(self.node_to_qualname.values()):
            # Check to see if existing_qualname is of the form
            # {node_qualname} or {node_qualname}_{int}
            if re.match(rf'{node_qualname}(_[0-9]+)?$',
                        existing_qualname) is not None:
                postfix = existing_qualname.replace(node_qualname, '')
                if len(postfix):
                    # existing_qualname is of the form {node_qualname}_{int}
                    next_index = int(postfix[1:]) + 1
                else:
                    # existing_qualname is of the form {node_qualname}
                    next_index = 1
                node_qualname += f'_{next_index}'
                break

        return node_qualname
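For illustration, here is a hypothetical standalone sketch (not part of the
torchvision API) of the renaming logic above: undo torch.fx's globally scoped
``_{int}`` postfix, then re-index against the names already recorded within
the same parent scope:

    import re

    seen = ['submodule.add', 'submodule.add_1', 'add']  # names recorded so far
    candidate = 'add_2'  # torch.fx's globally scoped name for a repeated op
    base = re.sub(r'_[0-9]+$', '', candidate)  # undo global postfix -> 'add'
    # count earlier occurrences of base (with or without a local index)
    matches = [q for q in seen if re.match(rf'{base}(_[0-9]+)?$', q)]
    print(base if not matches else f'{base}_{len(matches)}')  # -> 'add_1'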
@@ -171,19 +179,23 @@ def get_graph_node_names(
    names are available for feature extraction. There are two reasons that
    node names can't easily be read directly from the code for a model:

        1. Not all submodules are traced through. Modules from ``torch.nn`` all
           fall within this category.
        2. Nodes representing the repeated application of the same operation
           or leaf module get a ``_{counter}`` postfix.

    The model is traced twice: once in train mode, and once in eval mode. Both
    sets of node names are returned.

    For more details on the node naming conventions used here, please see the
    :ref:`relevant subheading <about-node-names>` in the
    `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.

    Args:
        model (nn.Module): model for which we'd like to print node names
        tracer_kwargs (dict, optional): a dictionary of keyword arguments for
            ``NodePathTracer`` (they are eventually passed onto
            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
        suppress_diff_warning (bool, optional): whether to suppress a warning
            when there are discrepancies between the train and eval version of
            the graph. Defaults to False.
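A minimal usage sketch (using a standard torchvision model):

    import torchvision
    from torchvision.models.feature_extraction import get_graph_node_names

    model = torchvision.models.resnet18()
    train_nodes, eval_nodes = get_graph_node_names(model)
    # the two lists coincide unless the model behaves differently in
    # train and eval mode
    print(train_nodes[:4])  # e.g. ['x', 'conv1', 'bn1', 'relu']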
@@ -289,58 +301,55 @@ def create_feature_extractor(
    the model via FX to return the desired nodes as outputs. All unused nodes
    are removed, together with their corresponding parameters.

    Desired output nodes must be specified as a ``.`` separated
    path walking the module hierarchy from top level module down to leaf
    operation or leaf module. For more details on the node naming conventions
    used here, please see the :ref:`relevant subheading <about-node-names>`
    in the `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.

    Not all models will be FX traceable, although with some massaging they can
    be made to cooperate. Here's a (not exhaustive) list of tips:

        - If you don't need to trace through a particular, problematic
          sub-module, turn it into a "leaf module" by passing a list of
          ``leaf_modules`` as one of the ``tracer_kwargs`` (see example below).
          It will not be traced through, but rather, the resulting graph will
          hold a reference to that module's forward method.
        - Likewise, you may turn functions into leaf functions by passing a
          list of ``autowrap_functions`` as one of the ``tracer_kwargs`` (see
          example below).
        - Some inbuilt Python functions can be problematic. For instance,
          ``int`` will raise an error during tracing. You may wrap them in your
          own function and then pass that in ``autowrap_functions`` as one of
          the ``tracer_kwargs``.

    For further information on FX see the
    `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_.

    Args:
        model (nn.Module): model on which we will extract the features
        return_nodes (list or dict, optional): either a ``List`` or a ``Dict``
            containing the names (or partial names - see note above)
            of the nodes for which the activations will be returned. If it is
            a ``Dict``, the keys are the node names, and the values
            are the user-specified keys for the graph module's returned
            dictionary. If it is a ``List``, it is treated as a ``Dict`` mapping
            node specification strings directly to output names. In the case
            that ``train_return_nodes`` and ``eval_return_nodes`` are specified,
            this should not be specified.
        train_return_nodes (list or dict, optional): similar to
            ``return_nodes``. This can be used if the return nodes
            for train mode are different than those from eval mode.
            If this is specified, ``eval_return_nodes`` must also be specified,
            and ``return_nodes`` should not be specified.
        eval_return_nodes (list or dict, optional): similar to
            ``return_nodes``. This can be used if the return nodes
            for train mode are different than those from eval mode.
            If this is specified, ``train_return_nodes`` must also be specified,
            and ``return_nodes`` should not be specified.
        tracer_kwargs (dict, optional): a dictionary of keyword arguments for
            ``NodePathTracer`` (which passes them onto its parent class
            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
        suppress_diff_warning (bool, optional): whether to suppress a warning
            when there are discrepancies between the train and eval version of
            the graph. Defaults to False.
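As a usage sketch of the above (the output shapes assume a standard ResNet-50
and a 224x224 input):

    import torch
    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import create_feature_extractor

    model = resnet50()
    # dict form: node name -> user-chosen key in the output dictionary
    extractor = create_feature_extractor(
        model, return_nodes={'layer2': 'mid', 'layer4': 'late'})
    out = extractor(torch.rand(1, 3, 224, 224))
    print({k: v.shape for k, v in out.items()})
    # {'mid': torch.Size([1, 512, 28, 28]), 'late': torch.Size([1, 2048, 7, 7])}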
...