feature_extraction.py 22.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
from typing import Dict, Callable, List, Union, Optional, Tuple
from collections import OrderedDict
import warnings
import re
from copy import deepcopy
from itertools import chain

import torch
from torch import nn
from torch import fx
from torch.fx.graph_module import _copy_attr


# Public API of this module: the two user-facing entry points defined below.
__all__ = [
    'create_feature_extractor',
    'get_graph_node_names',
]


class LeafModuleAwareTracer(fx.Tracer):
    """
    An ``fx.Tracer`` that lets the caller nominate extra "leaf" module types.

    Modules that are instances of any type listed in the ``leaf_modules``
    keyword argument are not traced through; the resulting graph instead holds
    a single node referencing each call to such a module. Every other
    positional and keyword argument is forwarded unchanged to ``fx.Tracer``.
    """
    def __init__(self, *args, **kwargs):
        # Strip our own option before the remaining kwargs reach fx.Tracer.
        # When it is absent, an empty container means "no extra leaf types".
        self.leaf_modules = kwargs.pop('leaf_modules', {})
        super(LeafModuleAwareTracer, self).__init__(*args, **kwargs)

    def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool:
        # A user-nominated leaf type always wins; otherwise fall back to the
        # default fx.Tracer policy.
        user_leaf_types = tuple(self.leaf_modules)
        return isinstance(m, user_leaf_types) or super().is_leaf_module(
            m, module_qualname)


class NodePathTracer(LeafModuleAwareTracer):
    """
    NodePathTracer is an FX tracer that, for each operation, also records the
    name of the Node from which the operation originated. A node name here is
    a `.` separated path walking the hierarchy from top level module down to
    leaf operation or leaf module. The name of the top level module is not
    included as part of the node name. For example, if we trace a module whose
    forward method applies a ReLU module, the name for that node will simply
    be 'relu'.

    Some notes on the specifics:
        - Nodes are recorded to `self.node_to_qualname` which is a dictionary
          mapping a given Node object to its node name.
        - Nodes are recorded in the order which they are executed during
          tracing.
        - When a duplicate node name is encountered, a suffix of the form
          _{int} is added. The counter starts from 1.
    """
    def __init__(self, *args, **kwargs):
        super(NodePathTracer, self).__init__(*args, **kwargs)
        # Track the qualified name of the Node being traced
        self.current_module_qualname = ''
        # A map from FX Node to the qualified name
        # NOTE: This is loosely like the "qualified name" mentioned in the
        # torch.fx docs https://pytorch.org/docs/stable/fx.html but adapted
        # for the purposes of the torchvision feature extractor
        self.node_to_qualname = OrderedDict()

    def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs):
        """
        Override of `fx.Tracer.call_module`
        This override:
        1) Stores away the qualified name of the caller for restoration later
        2) Adds the qualified name of the caller to
           `current_module_qualname` for retrieval by `create_proxy`
        3) Once a leaf module is reached, calls `create_proxy`
        4) Restores the caller's qualified name into current_module_qualname
        """
        old_qualname = self.current_module_qualname
        try:
            module_qualname = self.path_of_module(m)
            self.current_module_qualname = module_qualname
            if not self.is_leaf_module(m, module_qualname):
                # Trace through the module; nested calls update the qualname
                return forward(*args, **kwargs)
            return self.create_proxy('call_module', module_qualname, args, kwargs)
        finally:
            # Restore even if tracing the submodule raised
            self.current_module_qualname = old_qualname

    def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs,
                     name=None, type_expr=None, *_) -> fx.proxy.Proxy:
        """
        Override of `Tracer.create_proxy`. This override intercepts the recording
        of every operation and stores away the current traced module's qualified
        name in `node_to_qualname`
        """
        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
        self.node_to_qualname[proxy.node] = self._get_node_qualname(
            self.current_module_qualname, proxy.node)
        return proxy

    def _get_node_qualname(
            self, module_qualname: str, node: fx.node.Node) -> str:
        """
        Compute the node name recorded for ``node``.

        Args:
            module_qualname (str): qualified name of the module being traced
                when ``node`` was created ('' at the top level).
            node (fx.node.Node): the freshly created node.

        Returns:
            str: for a ``call_module`` node, ``module_qualname`` itself, or
            ``{module_qualname}_{int}`` when that name was already recorded
            (the counter continues from the most recently recorded match).
            For any other node, ``module_qualname`` and ``str(node)`` joined
            by ``.`` (no ``.`` at the top level).
        """
        node_qualname = module_qualname
        if node.op == 'call_module':
            # Node terminates in a leaf module so the module_qualname is a
            # complete description of the node. If this leaf module was
            # already called, disambiguate with a `_{int}` suffix.
            # The name is escaped so that `.` (and any other character in a
            # module path) is matched literally rather than as a regex token.
            pattern = re.compile(re.escape(node_qualname) + r'(_[0-9]+)?')
            for existing_qualname in reversed(self.node_to_qualname.values()):
                match = pattern.fullmatch(existing_qualname)
                if match is None:
                    continue
                counter_suffix = match.group(1)
                if counter_suffix is None:
                    # existing_qualname is of the form {node_qualname}
                    next_index = 1
                else:
                    # existing_qualname is of the form {node_qualname}_{int}.
                    # Read the int from the regex group: removing
                    # node_qualname with str.replace would also strip digits
                    # of the counter when the module name is numeric
                    # (e.g. '1' vs '1_1').
                    next_index = int(counter_suffix[1:]) + 1
                node_qualname += f'_{next_index}'
                break
        else:
            # Node terminates in non-leaf module so the node name needs to be
            # appended
            if len(node_qualname) > 0:
                # Only append '.' if we are deeper than the top level module
                node_qualname += '.'
            node_qualname += str(node)
        return node_qualname


def _is_subseq(x, y):
    """Check if y is a subseqence of x
    https://stackoverflow.com/a/24017747/4391249
    """
    iter_x = iter(x)
    return all(any(x_item == y_item for x_item in iter_x) for y_item in y)


def _warn_graph_differences(
        train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
    """
    Warn the user if tracing in train mode and in eval mode recorded different
    sequences of node names.

    Nothing is emitted when both sequences are identical. Otherwise the
    message says whether the eval-mode names are a subsequence of the
    train-mode names, the other way round, or neither, followed by a hint to
    specify output nodes per mode.
    """
    train_names = list(train_tracer.node_to_qualname.values())
    eval_names = list(eval_tracer.node_to_qualname.values())
    if train_names == eval_names:
        return

    if _is_subseq(train_names, eval_names):
        relation = ("NOTE: The nodes obtained by tracing the model in eval mode "
                    "are a subsequence of those obtained in train mode. ")
    elif _is_subseq(eval_names, train_names):
        relation = ("NOTE: The nodes obtained by tracing the model in train mode "
                    "are a subsequence of those obtained in eval mode. ")
    else:
        relation = ("The nodes obtained by tracing the model in train mode "
                    "are different to those obtained in eval mode. ")
    advice = (
        "When choosing nodes for feature extraction, you may need to specify "
        "output nodes for train and eval mode separately.")
    warnings.warn(relation + advice)


def get_graph_node_names(
        model: nn.Module, tracer_kwargs: Optional[Dict] = None,
        suppress_diff_warning: bool = False) -> Tuple[List[str], List[str]]:
    """
    Dev utility to return node names in order of execution. See note on node
    names under :func:`create_feature_extractor`. Useful for seeing which node
    names are available for feature extraction. There are two reasons that
    node names can't easily be read directly from the code for a model:

        1. Not all submodules are traced through. Modules from `torch.nn` all
           fall within this category.
        2. Nodes representing the repeated application of the same operation
           or leaf module get a `_{counter}` postfix.

    The model is traced twice: once in train mode, and once in eval mode. Both
    sets of nodes are returned. The model's original training mode is restored
    afterwards, even if tracing fails.

    Args:
        model (nn.Module): model for which we'd like to print node names
        tracer_kwargs (dict, optional): a dictionary of keyword arguments for
            `NodePathTracer` (they are eventually passed onto
            `torch.fx.Tracer`). Defaults to no extra arguments.
        suppress_diff_warning (bool, optional): whether to suppress a warning
            when there are discrepancies between the train and eval version of
            the graph. Defaults to False.

    Returns:
        tuple(list, list): a list of node names from tracing the model in
        train mode, and another from tracing the model in eval mode.

    Examples::

        >>> model = torchvision.models.resnet18()
        >>> train_nodes, eval_nodes = get_graph_node_names(model)
    """
    # Avoid a shared mutable default argument
    if tracer_kwargs is None:
        tracer_kwargs = {}
    is_training = model.training
    try:
        train_tracer = NodePathTracer(**tracer_kwargs)
        train_tracer.trace(model.train())
        eval_tracer = NodePathTracer(**tracer_kwargs)
        eval_tracer.trace(model.eval())
    finally:
        # Restore training state even if tracing raised, so a failed
        # inspection does not leave the caller's model in eval mode.
        model.train(is_training)
    train_nodes = list(train_tracer.node_to_qualname.values())
    eval_nodes = list(eval_tracer.node_to_qualname.values())
    if not suppress_diff_warning:
        _warn_graph_differences(train_tracer, eval_tracer)
    return train_nodes, eval_nodes


class DualGraphModule(fx.GraphModule):
    """
    A derivative of `fx.GraphModule`. Differs in the following ways:
    - Requires a train and eval version of the underlying graph
    - Copies submodules according to the nodes of both train and eval graphs.
    - Calling train(mode) switches between train graph and eval graph.
    """
    def __init__(self,
                 root: torch.nn.Module,
                 train_graph: fx.Graph,
                 eval_graph: fx.Graph,
                 class_name: str = 'GraphModule'):
        """
        Args:
            root (nn.Module): module from which the copied module hierarchy is
                built
            train_graph (fx.Graph): the graph that should be used in train mode
            eval_graph (fx.Graph): the graph that should be used in eval mode
            class_name (str, optional): value assigned to
                ``self.__class__.__name__``. Defaults to 'GraphModule'.

        Raises:
            AssertionError: if a ``get_attr``/``call_module`` node has a
                non-string target, or if the two graphs report different
                ``_tracer_cls`` values.
        """
        # Deliberately skip fx.GraphModule.__init__ (which accepts a single
        # graph) and run the next __init__ in the MRO instead (presumably
        # nn.Module's); the graph is installed manually further below.
        super(fx.GraphModule, self).__init__()

        # NOTE(review): this renames the class object of this instance. It
        # appears to rely on fx.GraphModule giving each instance its own
        # class so other instances are unaffected - confirm for the torch
        # version in use.
        self.__class__.__name__ = class_name

        self.train_graph = train_graph
        self.eval_graph = eval_graph

        # Copy all get_attr and call_module ops (indicated by BOTH train and
        # eval graphs)
        # Targets referenced by either graph are copied from `root`, so this
        # module holds every submodule/attribute that either mode needs.
        for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
            if node.op in ['get_attr', 'call_module']:
                assert isinstance(node.target, str)
                _copy_attr(root, self, node.target)

        # train mode by default
        # NOTE: statement order matters. The overridden train() below only
        # assigns self.graph when switching modes; the module is presumably
        # already in training mode after the base __init__, so the train
        # graph is installed explicitly on the next line.
        self.train()
        self.graph = train_graph

        # (borrowed from fx.GraphModule):
        # Store the Tracer class responsible for creating a Graph separately as part of the
        # GraphModule state, except when the Tracer is defined in a local namespace.
        # Locally defined Tracers are not pickleable. This is needed because torch.package will
        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
        # to re-create the Graph during deserialization.
        assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \
            "Train mode and eval mode should use the same tracer class"
        self._tracer_cls = None
        if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
            self._tracer_cls = self.graph._tracer_cls

    def train(self, mode=True):
        """
        Swap out the graph depending on the selected training mode.
        NOTE this should be safe when calling model.eval() because that just
        calls this with mode == False.

        Args:
            mode (bool, optional): True selects ``train_graph``, False
                selects ``eval_graph``. Defaults to True.

        Returns:
            The result of the base class ``train(mode=mode)``.
        """
        # NOTE: Only set self.graph if the current graph is not the desired
        # one. This saves us from recompiling the graph where not necessary.
        if mode and not self.training:
            self.graph = self.train_graph
        elif not mode and self.training:
            self.graph = self.eval_graph
        return super().train(mode=mode)


def create_feature_extractor(
        model: nn.Module,
        return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
        train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
        eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
        tracer_kwargs: Optional[Dict] = None,
        suppress_diff_warning: bool = False) -> fx.GraphModule:
    """
    Creates a new graph module that returns intermediate nodes from a given
    model as a dictionary with user specified keys as strings, and the
    requested outputs as values. This is achieved by re-writing the
    computation graph of the model via FX to return the desired nodes as
    outputs. All unused nodes are removed, together with their corresponding
    parameters.

    A note on node specification: For the purposes of this feature extraction
    utility, a node name is specified as a `.` separated path walking the
    hierarchy from top level module down to leaf operation or leaf module. For
    instance `blocks.5.3.bn1`. The keys of the `return_nodes` argument should
    point to either a node's name, or some truncated version of it. For
    example, one could provide `blocks.5` as a key, and the last node with
    that prefix will be selected. :func:`get_graph_node_names` is a useful
    helper function for getting a list of node names of a model.

    Not all models will be FX traceable, although with some massaging they can
    be made to cooperate. Here's a (not exhaustive) list of tips:

        - If you don't need to trace through a particular, problematic
          sub-module, turn it into a "leaf module" by passing a list of
          `leaf_modules` as one of the `tracer_kwargs` (see example below). It
          will not be traced through, but rather, the resulting graph will
          hold a reference to that module's forward method.
        - Likewise, you may turn functions into leaf functions by passing a
          list of `autowrap_functions` as one of the `tracer_kwargs` (see
          example below).
        - Some inbuilt Python functions can be problematic. For instance,
          `int` will raise an error during tracing. You may wrap them in your
          own function and then pass that in `autowrap_functions` as one of
          the `tracer_kwargs`.

    For further information on FX see the
    `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_.

    Args:
        model (nn.Module): model on which we will extract the features
        return_nodes (list or dict, optional): either a `List` or a `Dict`
            containing the names (or partial names - see note above)
            of the nodes for which the activations will be returned. If it is
            a `Dict`, the keys are the node names, and the values
            are the user-specified keys for the graph module's returned
            dictionary. If it is a `List`, it is treated as a `Dict` mapping
            node specification strings directly to output names. In the case
            that `train_return_nodes` and `eval_return_nodes` are specified,
            this should not be specified.
        train_return_nodes (list or dict, optional): similar to
            `return_nodes`. This can be used if the return nodes
            for train mode are different than those from eval mode.
            If this is specified, `eval_return_nodes` must also be specified,
            and `return_nodes` should not be specified.
        eval_return_nodes (list or dict, optional): similar to
            `return_nodes`. This can be used if the return nodes
            for train mode are different than those from eval mode.
            If this is specified, `train_return_nodes` must also be specified,
            and `return_nodes` should not be specified.
        tracer_kwargs (dict, optional): a dictionary of keyword arguments for
            `NodePathTracer` (which passes them onto its parent class
            `torch.fx.Tracer`). Defaults to no extra arguments.
        suppress_diff_warning (bool, optional): whether to suppress a warning
            when there are discrepancies between the train and eval version of
            the graph. Defaults to False.

    Returns:
        fx.GraphModule: a :class:`DualGraphModule` whose forward returns a
        dictionary mapping the requested output names to the values of the
        selected nodes. It switches graphs on ``train()``/``eval()`` and is
        left in the same training mode as ``model``. ``model`` itself is
        restored to its original training mode, even if tracing fails.

    Raises:
        AssertionError: if the combination of `return_nodes`,
            `train_return_nodes` and `eval_return_nodes` is invalid.
        ValueError: if a requested node is not present in the traced model.

    Examples::

        >>> # Feature extraction with resnet
        >>> model = torchvision.models.resnet18()
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> model = create_feature_extractor(
        >>>     model, {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = model(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]

        >>> # Specifying leaf modules and leaf functions
        >>> def leaf_function(x):
        >>>     # This would raise a TypeError if traced through
        >>>     return int(x)
        >>>
        >>> class LeafModule(torch.nn.Module):
        >>>     def forward(self, x):
        >>>         # This would raise a TypeError if traced through
        >>>         int(x.shape[0])
        >>>         return torch.nn.functional.relu(x + 4)
        >>>
        >>> class MyModule(torch.nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.conv = torch.nn.Conv2d(3, 1, 3)
        >>>         self.leaf_module = LeafModule()
        >>>
        >>>     def forward(self, x):
        >>>         leaf_function(x.shape[0])
        >>>         x = self.conv(x)
        >>>         return self.leaf_module(x)
        >>>
        >>> model = create_feature_extractor(
        >>>     MyModule(), return_nodes=['leaf_module'],
        >>>     tracer_kwargs={'leaf_modules': [LeafModule],
        >>>                    'autowrap_functions': [leaf_function]})

    """
    # Avoid a shared mutable default argument
    if tracer_kwargs is None:
        tracer_kwargs = {}
    is_training = model.training

    assert any(arg is not None for arg in [
        return_nodes, train_return_nodes, eval_return_nodes]), (
            "Either `return_nodes` or `train_return_nodes` and "
            "`eval_return_nodes` together, should be specified")

    assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \
        ("If any of `train_return_nodes` and `eval_return_nodes` are "
         "specified, then both should be specified")

    # The all-None case was rejected above, so this only fails when
    # `return_nodes` is given together with the train/eval pair.
    assert ((return_nodes is None) ^ (train_return_nodes is None)), \
        ("If `train_return_nodes` and `eval_return_nodes` are specified, "
         "then `return_nodes` should not be specified")

    # Put *_return_nodes into Dict[str, str] format
    def to_strdict(n) -> Dict[str, str]:
        if isinstance(n, list):
            return {str(i): str(i) for i in n}
        return {str(k): str(v) for k, v in n.items()}

    if train_return_nodes is None:
        return_nodes = to_strdict(return_nodes)
        train_return_nodes = deepcopy(return_nodes)
        eval_return_nodes = deepcopy(return_nodes)
    else:
        train_return_nodes = to_strdict(train_return_nodes)
        eval_return_nodes = to_strdict(eval_return_nodes)

    # Repeat the tracing and graph rewriting for train and eval mode
    tracers = {}
    graphs = {}
    mode_return_nodes: Dict[str, Dict[str, str]] = {
        'train': train_return_nodes,
        'eval': eval_return_nodes
    }
    try:
        for mode in ['train', 'eval']:
            if mode == 'train':
                model.train()
            elif mode == 'eval':
                model.eval()

            # Instantiate our NodePathTracer and use that to trace the model
            tracer = NodePathTracer(**tracer_kwargs)
            graph = tracer.trace(model)

            name = model.__class__.__name__ if isinstance(
                model, nn.Module) else model.__name__
            graph_module = fx.GraphModule(tracer.root, graph, name)

            available_nodes = list(tracer.node_to_qualname.values())
            # FIXME We don't know if we should expect this to happen
            assert len(set(available_nodes)) == len(available_nodes), \
                "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
            # Check that all outputs in return_nodes are present in the model
            for query in mode_return_nodes[mode]:
                # A query is available if it names a node exactly or is a
                # `.`-delimited prefix of one. Plain string comparison treats
                # `.` and other regex metacharacters in the query literally,
                # consistent with the prefix matching used to pick outputs.
                if not any(n == query or n.startswith(query + '.')
                           for n in available_nodes):
                    raise ValueError(
                        f"node: '{query}' is not present in model. Hint: use "
                        "`get_graph_node_names` to make sure the "
                        "`return_nodes` you specified are present. It may even "
                        "be that you need to specify `train_return_nodes` and "
                        "`eval_return_nodes` separately.")

            # Remove existing output nodes
            orig_output_nodes = [
                n for n in reversed(graph_module.graph.nodes)
                if n.op == 'output']
            assert len(orig_output_nodes)
            for n in orig_output_nodes:
                graph_module.graph.erase_node(n)

            # Find nodes corresponding to return_nodes and make them into
            # output_nodes
            nodes = [n for n in graph_module.graph.nodes]
            output_nodes = OrderedDict()
            remaining = mode_return_nodes[mode]
            for n in reversed(nodes):
                if not remaining:
                    break
                module_qualname = tracer.node_to_qualname.get(n)
                if module_qualname is None:
                    # NOTE - Know cases where this happens:
                    # - Node representing creation of a tensor constant - probably
                    #   not interesting as a return node
                    # - When packing outputs into a named tuple like in InceptionV3
                    continue
                # Walking backwards, the first node with a matching prefix is
                # the LAST such node in execution order. Several queries may
                # select the same node (e.g. 'layer1' and 'layer1.1'), so test
                # every remaining query instead of stopping at the first hit.
                for query in list(remaining):
                    depth = query.count('.')
                    if '.'.join(module_qualname.split('.')[:depth + 1]) == query:
                        output_nodes[remaining.pop(query)] = n
            # Restore execution order for the outputs
            output_nodes = OrderedDict(reversed(list(output_nodes.items())))

            # And add them in the end of the graph
            with graph_module.graph.inserting_after(nodes[-1]):
                graph_module.graph.output(output_nodes)

            # Remove unused modules / parameters
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()

            # Keep track of the tracer and graph so we can choose the main one
            tracers[mode] = tracer
            graphs[mode] = graph

        # Warn user if there are any discrepancies between the graphs of the
        # train and eval modes
        if not suppress_diff_warning:
            _warn_graph_differences(tracers['train'], tracers['eval'])

        # Build the final graph module
        graph_module = DualGraphModule(
            model, graphs['train'], graphs['eval'], class_name=name)
    finally:
        # Restore original training mode, even if tracing/rewriting failed
        model.train(is_training)

    graph_module.train(is_training)
    return graph_module