# extract_feature.py
import functools
from collections import OrderedDict


# using wonder's beautiful simplification:
# https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427
def rgetattr(obj, attr, *args):
    """Recursive ``getattr`` that understands dotted attribute paths.

    ``rgetattr(obj, 'a.b.c')`` is equivalent to ``obj.a.b.c``. An optional
    third positional argument acts as the default, applied at *every* level
    of the lookup (mirroring plain ``getattr``).
    """
    current = obj
    for part in attr.split('.'):
        current = getattr(current, part, *args)
    return current


class IntermediateLayerGetter:
    """Wrap a PyTorch module to capture intermediate submodule outputs.

    Temporary forward hooks are registered on the requested submodules,
    the wrapped model is called once, and the hooks are removed again —
    even if the forward pass raises.
    """

    def __init__(self, model, return_layers, keep_output=True):
        """Wraps a Pytorch module to get intermediate values

        Arguments:
            model {nn.module} -- The Pytorch module to call
            return_layers {dict} -- Dictionary with the selected submodules
            to return the output (format: {[current_module_name]: [desired_output_name]},
            current_module_name can be a nested submodule, e.g. submodule1.submodule2.submodule3)

        Keyword Arguments:
            keep_output {bool} -- If True model_output contains the final model's output
            in the other case model_output is None (default: {True})

        Returns:
            (mid_outputs {OrderedDict}, model_output {any}) -- mid_outputs keys are
            your desired_output_name (s) and their values are the returned tensors
            of those submodules (OrderedDict([(desired_output_name,tensor(...)), ...).
            See keep_output argument for model_output description.
            In case a submodule is called more than one time, all it's outputs are
            stored in a list.
        """
        self._model = model
        self.return_layers = return_layers
        self.keep_output = keep_output

    def __call__(self, *args, **kwargs):
        """Run the wrapped model; return (mid_outputs, model_output).

        Raises:
            AttributeError -- if a key of ``return_layers`` does not name a
            submodule of the model (or names something that is not a module).
        """
        ret = OrderedDict()
        handles = []
        try:
            for name, new_name in self.return_layers.items():
                # Bind new_name as a default argument so each hook keeps its
                # own key (avoids the late-binding closure pitfall).
                def hook(module, input, output, new_name=new_name):
                    if new_name in ret:
                        # Submodule fired more than once: collect all outputs.
                        if type(ret[new_name]) is list:
                            ret[new_name].append(output)
                        else:
                            ret[new_name] = [ret[new_name], output]
                    else:
                        ret[new_name] = output

                try:
                    # Resolve a possibly dotted path, e.g. 'levels.0.blocks.0'
                    # (equivalent to rgetattr with no default).
                    layer = functools.reduce(getattr, name.split('.'), self._model)
                    handles.append(layer.register_forward_hook(hook))
                except AttributeError as e:
                    # Both a bad path and a non-module target land here, so the
                    # friendly message fires in every lookup-failure case.
                    raise AttributeError(f'Module {name} not found') from e

            output = self._model(*args, **kwargs)
            if not self.keep_output:
                output = None
        finally:
            # Always detach the hooks, even when lookup or forward raised,
            # so repeated/failed calls never leave stale hooks on the model.
            for h in handles:
                h.remove()

        return ret, output


def main(args, config):
    """Build the model from config, run one image through it, and return the
    requested intermediate activations.

    Arguments:
        args -- parsed CLI namespace; reads ``args.keys`` (submodule paths,
            may be None) and ``args.img`` (path to the input image).
        config -- project config object; reads ``MODEL.RESUME``,
            ``DATA.IMG_SIZE``, ``AUG.MEAN`` and ``AUG.STD``.

    Returns:
        (mid_outputs {OrderedDict}, model_output) -- see IntermediateLayerGetter.
    """
    # Import torch locally so main() works even when this module is imported
    # (previously it depended on the import inside the __main__ guard).
    import torch
    from models import build_model
    import torchvision.transforms as T
    from PIL import Image

    model = build_model(config)
    checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=False)
    model.cuda()
    # Inference mode: freeze dropout/batch-norm statistics.
    model.eval()

    # examples:
    # return_layers = {
    #     'patch_embed': 'patch_embed',
    #     'levels.0.downsample': 'levels.0.downsample',
    #     'levels.0.blocks.0.dcn': 'levels.0.blocks.0.dcn',
    # }
    # args.keys defaults to None; treat that as "no intermediate layers".
    return_layers = {k: k for k in (args.keys or [])}
    mid_getter = IntermediateLayerGetter(model, return_layers=return_layers, keep_output=True)

    # Force 3 channels so grayscale/RGBA inputs match the 3-channel
    # normalization statistics instead of crashing in Normalize.
    image = Image.open(args.img).convert('RGB')

    transforms = T.Compose([
        T.Resize(config.DATA.IMG_SIZE),
        T.ToTensor(),
        T.Normalize(config.AUG.MEAN, config.AUG.STD)
    ])
    image = transforms(image)
    image = image.unsqueeze(0)
    image = image.cuda()

    # No gradients needed for feature extraction.
    with torch.no_grad():
        mid_outputs, model_output = mid_getter(image)

    for k, v in mid_outputs.items():
        # A submodule invoked more than once yields a list of tensors.
        if isinstance(v, list):
            print(k, [t.shape for t in v])
        else:
            print(k, v.shape)

    return mid_outputs, model_output


if __name__ == '__main__':
    import argparse
    import os
    import torch
    from config import get_config

    parser = argparse.ArgumentParser('Get Intermediate Layer Output')
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='Path to config file')
    parser.add_argument('--img', type=str, required=True, metavar="FILE", help='Path to img file')
    parser.add_argument("--keys", default=None, nargs='+', help="The intermediate layer's keys you want to save.")
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--save', action='store_true', help='Save the results.')
    args = parser.parse_args()
    config = get_config(args)

    mid_outputs, model_output = main(args, config)

    if args.save:
        # Replace the image extension properly: the old `args.img[:-3]` assumed
        # a 3-char extension and produced e.g. 'photo..pth' for 'photo.jpg'
        # (and mangled '.jpeg'); splitext handles any extension.
        torch.save(mid_outputs, os.path.splitext(args.img)[0] + '.pth')