torchvision.models.feature_extraction
=====================================

.. currentmodule:: torchvision.models.feature_extraction

Feature extraction utilities let us tap into our models to access intermediate
transformations of our inputs. This could be useful for a variety of
applications in computer vision. Just a few examples are:

- Visualizing feature maps.
- Extracting features to compute image descriptors for tasks like facial
  recognition, copy-detection, or image retrieval.
- Passing selected features to downstream sub-networks for end-to-end training
  with a specific task in mind. For example, passing a hierarchy of features
  to a Feature Pyramid Network with object detection heads.

Torchvision provides :func:`create_feature_extractor` for this purpose.
It works by following roughly these steps:

1. Symbolically tracing the model to get a graphical representation of
   how it transforms the input, step by step.
2. Setting the user-selected graph nodes as outputs.
3. Removing all redundant nodes (anything downstream of the output nodes).
4. Generating python code from the resulting graph and bundling that into a
   PyTorch module together with the graph itself.

|

The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.
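
As a rough illustration of the four steps, the sketch below applies plain
``torch.fx`` to a small toy module. The ``TinyNet`` class and its layer names
are hypothetical, and this is not the internal implementation of
:func:`create_feature_extractor` (which also deals with qualified node names
and train/eval modes); it only shows the general idea under those assumptions.

.. code-block:: python

  import torch
  from torch import fx, nn


  class TinyNet(nn.Module):  # hypothetical toy model, for illustration only
      def __init__(self):
          super().__init__()
          self.fc1 = nn.Linear(4, 8)
          self.act = nn.ReLU()
          self.fc2 = nn.Linear(8, 2)

      def forward(self, x):
          return self.fc2(self.act(self.fc1(x)))


  # Step 1: symbolically trace the model into a graph of nodes.
  gm = fx.symbolic_trace(TinyNet())

  # Step 2: make the intermediate node 'act' the graph output.
  nodes = {n.name: n for n in gm.graph.nodes}
  output_node = next(n for n in gm.graph.nodes if n.op == 'output')
  output_node.args = (nodes['act'],)

  # Step 3: drop nodes that no longer contribute to the output (here 'fc2').
  gm.graph.eliminate_dead_code()

  # Step 4: regenerate the Python code of the module from the edited graph.
  gm.recompile()

  out = gm(torch.randn(3, 4))
  remaining = [n.name for n in gm.graph.nodes]

After these steps, ``out`` holds the activations produced by ``act`` and the
``fc2`` call no longer appears in the graph.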

Here is an example of how we might extract features for MaskRCNN:

.. code-block:: python

  import torch
  from torchvision.models import resnet50
  from torchvision.models.feature_extraction import get_graph_node_names
  from torchvision.models.feature_extraction import create_feature_extractor
  from torchvision.models.detection.mask_rcnn import MaskRCNN
  from torchvision.models.detection.backbone_utils import LastLevelMaxPool
  from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork


  # To assist you in designing the feature extractor you may want to print out
  # the available nodes for resnet50.
  m = resnet50()
  train_nodes, eval_nodes = get_graph_node_names(m)

  # The returned lists contain the names of all the graph nodes (in order of
  # execution) for the input model traced in train mode and in eval mode
  # respectively. You'll find that `train_nodes` and `eval_nodes` are the same
  # for this example. But if the model contains control flow that's dependent
  # on the training mode, they may be different.

  # To specify the nodes you want to extract, you could select the final node
  # that appears in each of the main layers:
  return_nodes = {
      # node_name: user-specified key for output dict
      'layer1.2.relu_2': 'layer1',
      'layer2.3.relu_2': 'layer2',
      'layer3.5.relu_2': 'layer3',
      'layer4.2.relu_2': 'layer4',
  }

  # But `create_feature_extractor` can also accept truncated node specifications
  # like "layer1", as it will just pick the last node that's a descendant of
  # the specification. (Tip: be careful with this, especially when a layer
  # has multiple outputs. It's not always guaranteed that the last operation
  # performed is the one that corresponds to the output you desire. You should
  # consult the source code for the input model to confirm.)
  return_nodes = {
      'layer1': 'layer1',
      'layer2': 'layer2',
      'layer3': 'layer3',
      'layer4': 'layer4',
  }

  # Now you can build the feature extractor. This returns a module whose forward
  # method returns a dictionary like:
  # {
  #     'layer1': output of layer 1,
  #     'layer2': output of layer 2,
  #     'layer3': output of layer 3,
  #     'layer4': output of layer 4,
  # }
  create_feature_extractor(m, return_nodes=return_nodes)

  # Let's put all that together to wrap resnet50 with MaskRCNN

  # MaskRCNN requires a backbone with an attached FPN
  class Resnet50WithFPN(torch.nn.Module):
      def __init__(self):
          super().__init__()
          # Get a resnet50 backbone
          m = resnet50()
          # Extract 4 main layers (note: MaskRCNN needs this particular name
          # mapping for return nodes)
          self.body = create_feature_extractor(
              m, return_nodes={f'layer{k}': str(v)
                               for v, k in enumerate([1, 2, 3, 4])})
          # Dry run to get number of channels for FPN
          inp = torch.randn(2, 3, 224, 224)
          with torch.no_grad():
              out = self.body(inp)
          in_channels_list = [o.shape[1] for o in out.values()]
          # Build FPN
          self.out_channels = 256
          self.fpn = FeaturePyramidNetwork(
              in_channels_list, out_channels=self.out_channels,
              extra_blocks=LastLevelMaxPool())

      def forward(self, x):
          x = self.body(x)
          x = self.fpn(x)
          return x


  # Now we can build our model!
  model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()
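
A quick way to sanity-check a feature extractor on its own is to run it on a
dummy batch and inspect the shapes of the returned tensors. The sketch below
assumes a randomly initialized ``resnet50`` and a 224x224 input; the output
keys ``feat1`` and ``feat4`` are arbitrary names chosen for this illustration.

.. code-block:: python

  import torch
  from torchvision.models import resnet50
  from torchvision.models.feature_extraction import create_feature_extractor

  # Randomly initialized weights are enough to check shapes.
  backbone = resnet50()
  extractor = create_feature_extractor(
      backbone, return_nodes={'layer1': 'feat1', 'layer4': 'feat4'})
  extractor.eval()

  # Forward a dummy batch; the result is a dict keyed by the chosen names.
  with torch.no_grad():
      features = extractor(torch.randn(1, 3, 224, 224))

  shapes = {name: tuple(t.shape) for name, t in features.items()}
  print(shapes)

For this input size, ``feat1`` should have 256 channels at 56x56 resolution
and ``feat4`` 2048 channels at 7x7.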


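The comments in the example note that the node lists for train mode and eval
mode can differ when the model's control flow depends on the training mode.
A minimal sketch of such a case follows; ``TrainOnlyDropout`` is a
hypothetical toy module, not part of torchvision. Note that
:func:`get_graph_node_names` may emit a warning about the differing graphs.

.. code-block:: python

  import torch
  from torch import nn
  from torchvision.models.feature_extraction import get_graph_node_names


  class TrainOnlyDropout(nn.Module):  # hypothetical toy model
      def __init__(self):
          super().__init__()
          self.fc = nn.Linear(4, 4)
          self.drop = nn.Dropout(p=0.5)

      def forward(self, x):
          x = self.fc(x)
          # Python-level branch on the training flag: the tracer records
          # only the branch that is taken in the mode being traced.
          if self.training:
              x = self.drop(x)
          return x


  train_nodes, eval_nodes = get_graph_node_names(TrainOnlyDropout())
  print(train_nodes)
  print(eval_nodes)

Here the ``drop`` node is expected to appear only in ``train_nodes``, so a
feature extractor that requests it would only be valid in train mode.
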
API Reference
-------------

.. autofunction:: create_feature_extractor

.. autofunction:: get_graph_node_names