extract_feature.py 4.46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import functools
from collections import OrderedDict


# Recursive getattr, after "wonder"'s simplification on Stack Overflow:
# https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427
def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path (e.g. ``'a.b.c'``) on ``obj``.

    Mirrors built-in ``getattr``: an optional extra positional argument
    is used as the default at every step of the lookup chain.
    """
    target = obj
    for part in attr.split('.'):
        target = getattr(target, part, *args)
    return target


class IntermediateLayerGetter:
    """Wraps a PyTorch module to capture selected submodules' outputs
    via forward hooks, without modifying the wrapped model."""

    def __init__(self, model, return_layers, keep_output=True):
        """Wraps a Pytorch module to get intermediate values

        Arguments:
            model {nn.module} -- The Pytorch module to call
            return_layers {dict} -- Dictionary with the selected submodules
            to return the output (format: {[current_module_name]: [desired_output_name]},
            current_module_name can be a nested submodule, e.g. submodule1.submodule2.submodule3)

        Keyword Arguments:
            keep_output {bool} -- If True model_output contains the final model's output
            in the other case model_output is None (default: {True})

        Returns:
            (mid_outputs {OrderedDict}, model_output {any}) -- mid_outputs keys are
            your desired_output_name (s) and their values are the returned tensors
            of those submodules (OrderedDict([(desired_output_name,tensor(...)), ...).
            See keep_output argument for model_output description.
            In case a submodule is called more than one time, all it's outputs are
            stored in a list.
        """
        self._model = model
        self.return_layers = return_layers
        self.keep_output = keep_output

    def __call__(self, *args, **kwargs):
        """Run the wrapped model and collect the requested intermediate outputs.

        Returns:
            (mid_outputs {OrderedDict}, model_output) -- see __init__ docstring.
        """
        ret = OrderedDict()
        handles = []
        try:
            for name, new_name in self.return_layers.items():
                layer = rgetattr(self._model, name)

                # new_name is bound as a default argument so each hook keeps
                # its own name (avoids the late-binding closure pitfall).
                def hook(module, input, output, new_name=new_name):
                    if new_name in ret:
                        # Submodule was called more than once: accumulate all
                        # of its outputs in a list, in call order.
                        if isinstance(ret[new_name], list):
                            ret[new_name].append(output)
                        else:
                            ret[new_name] = [ret[new_name], output]
                    else:
                        ret[new_name] = output

                try:
                    h = layer.register_forward_hook(hook)
                except AttributeError as e:
                    # The dotted path resolved to something that is not an
                    # nn.Module (no register_forward_hook). Chain the cause.
                    raise AttributeError(f'Module {name} not found') from e
                handles.append(h)

            if self.keep_output:
                output = self._model(*args, **kwargs)
            else:
                self._model(*args, **kwargs)
                output = None
        finally:
            # Always detach the hooks — even if registration or the forward
            # pass raised — so the wrapped model is left unmodified.
            for h in handles:
                h.remove()

        return ret, output


def main(args, config):
    """Build the model from `config`, run one image through it, and print the
    shapes of the requested intermediate activations.

    Arguments:
        args -- parsed CLI namespace; uses args.keys (list of dotted submodule
            paths to capture) and args.img (path to the input image).
        config -- project config object (as returned by config.get_config).

    Returns:
        (mid_outputs {OrderedDict}, model_output) -- as produced by
        IntermediateLayerGetter.__call__.
    """
    # Import here so `main` works even when this module is imported rather
    # than executed as a script (torch was previously only imported inside
    # the __main__ guard, so a plain `import extract_feature; main(...)`
    # raised NameError on torch).
    import torch
    import torchvision.transforms as T
    from PIL import Image

    from models import build_model

    model = build_model(config)
    checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=False)
    model.cuda()

    # examples:
    # return_layers = {
    #     'patch_embed': 'patch_embed',
    #     'levels.0.downsample': 'levels.0.downsample',
    #     'levels.0.blocks.0.dcn': 'levels.0.blocks.0.dcn',
    # }
    return_layers = {k: k for k in args.keys}
    mid_getter = IntermediateLayerGetter(model, return_layers=return_layers, keep_output=True)

    # Force RGB so the 3-channel Normalize below also accepts grayscale or
    # RGBA inputs.
    image = Image.open(args.img).convert('RGB')

    transforms = T.Compose([
        T.Resize(config.DATA.IMG_SIZE),
        T.ToTensor(),
        T.Normalize(config.AUG.MEAN, config.AUG.STD),
    ])
    image = transforms(image).unsqueeze(0).cuda()

    # Pure feature extraction — no gradients needed.
    with torch.no_grad():
        mid_outputs, model_output = mid_getter(image)

    for k, v in mid_outputs.items():
        # A submodule called more than once yields a list of tensors
        # (see IntermediateLayerGetter); `.shape` on a list would crash.
        if isinstance(v, list):
            print(k, [t.shape for t in v])
        else:
            print(k, v.shape)

    return mid_outputs, model_output


if __name__ == '__main__':
    import argparse
    from pathlib import Path

    import torch

    from config import get_config

    parser = argparse.ArgumentParser('Get Intermediate Layer Output')
    parser.add_argument('--cfg', type=str, required=True, metavar='FILE', help='Path to config file')
    parser.add_argument('--img', type=str, required=True, metavar='FILE', help='Path to img file')
    parser.add_argument('--keys', default=None, nargs='+', help="The intermediate layer's keys you want to save.")
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--save', action='store_true', help='Save the results.')
    args = parser.parse_args()

    # Without --keys, main() would crash iterating None; fail with a clear
    # CLI error instead.
    if not args.keys:
        parser.error('--keys must name at least one submodule to capture')

    config = get_config(args)

    mid_outputs, model_output = main(args, config)

    if args.save:
        # Path.with_suffix handles any extension length; the previous
        # `args.img[:-3] + '.pth'` assumed a 3-character extension and
        # produced names like "img..pth".
        torch.save(mid_outputs, str(Path(args.img).with_suffix('.pth')))