from collections import OrderedDict

import torch
from torch import nn
from torch.jit.annotations import Dict


class IntermediateLayerGetter(nn.Module):
    """
    Module wrapper that returns intermediate layers from a model

    It makes the strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict whose keys are the
            names of the modules for which the activations will be
            returned, and whose values are the (user-specified) names
            given to the returned activations.

    Examples::

        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        [('feat1', torch.Size([1, 64, 56, 56])),
         ('feat2', torch.Size([1, 256, 14, 14]))]
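
    Only direct children of the model can be requested; asking for a
    nested module fails at construction time (a minimal sketch; the
    nested key 'layer1.0' is illustrative):

        >>> torchvision.models._utils.IntermediateLayerGetter(m, {'layer1.0': 'feat'})
        Traceback (most recent call last):
            ...
        ValueError: return_layers are not present in model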
    """
    _version = 2  # bumped when submodules were nested under "layers." for torchscript
    __constants__ = ['layers']
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model, return_layers):
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        super(IntermediateLayerGetter, self).__init__()

        orig_return_layers = return_layers
        # work on a copy so entries can be consumed as they are found
        return_layers = {k: v for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                # every requested layer has been seen; drop the rest
                break

        self.layers = nn.ModuleDict(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.layers.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out

    @torch.jit.ignore
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        if version is None or version < 2:
            # version 2 added a nesting level ("layers.") for torchscript
            # support; remap keys from older checkpoints accordingly
            for new_key in self.state_dict().keys():
                # the corresponding old key lacks the "layers." prefix
                old_key = new_key[len("layers."):]
                old_key = prefix + old_key
                new_key = prefix + new_key
                if old_key in state_dict:
                    value = state_dict[old_key]
                    del state_dict[old_key]
                    state_dict[new_key] = value
        super(IntermediateLayerGetter, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)
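

if __name__ == "__main__":
    # A minimal, self-contained smoke test (module and feature names are
    # illustrative) of the ordering behaviour documented above: direct
    # children are copied in registration order, and everything after the
    # last requested layer is dropped.
    backbone = nn.Sequential(OrderedDict([
        ("stem", nn.Conv2d(3, 8, kernel_size=3, padding=1)),
        ("block1", nn.Conv2d(8, 16, kernel_size=3, padding=1)),
        ("block2", nn.Conv2d(16, 32, kernel_size=3, padding=1)),
        ("head", nn.AdaptiveAvgPool2d(1)),
    ]))
    getter = IntermediateLayerGetter(backbone, {"stem": "low", "block2": "high"})
    feats = getter(torch.rand(1, 3, 32, 32))
    assert "head" not in getter.layers  # registered after 'block2', so dropped
    print([(k, tuple(v.shape)) for k, v in feats.items()])
    # [('low', (1, 8, 32, 32)), ('high', (1, 32, 32, 32))]

    # A hedged sketch of the version-1 -> version-2 migration handled by
    # _load_from_state_dict above: a checkpoint saved without the
    # "layers." prefix (the pre-torchscript layout) still loads cleanly.
    old_style = {k[len("layers."):]: v for k, v in getter.state_dict().items()}
    getter.load_state_dict(old_style)
    print("old-style state dict loaded OK")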