"vscode:/vscode.git/clone" did not exist on "0a9bfc20ab8e6d2744ae4588dab34cf7ed3a5980"
FeatureExtractor.py 715 Bytes
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch.nn as nn
from IPython import embed

class FeatureExtractor(nn.Module):
    """Collect intermediate activations from the ``base`` child of a model.

    The wrapped ``submodule`` is walked child by child in registration order.
    Only the child registered under the name ``"base"`` is executed: each of
    its own children is applied to the running tensor in turn, and the result
    is recorded whenever that child's name appears in ``extracted_layers``.

    A child registered under the name ``"classfier"`` is not executed; when it
    is encountered, the running tensor is flattened to
    ``(x.size(0), -1)``. Children with any other name are ignored.

    NOTE(review): ``"classfier"`` looks like a misspelling of "classifier",
    but it has to match the attribute name used by the wrapped model, which
    is not visible here -- confirm before renaming.

    Args:
        submodule: The ``nn.Module`` whose ``base`` child is traversed.
        extracted_layers: Container of child names (within ``base``) whose
            outputs should be returned; only ``in`` membership is used.
    """

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Run ``x`` through ``submodule.base`` and return selected outputs.

        Args:
            x: Input tensor fed to the first child of ``base``.

        Returns:
            A list with the output of every ``base`` child whose name is in
            ``extracted_layers``, in execution order. The list is empty when
            no name matches or when ``submodule`` has no ``base`` child.
        """
        outputs = []
        # ``_modules`` is iterated directly (rather than ``named_children()``)
        # so that a module instance registered under several names is still
        # applied once per registration.
        for name, module in self.submodule._modules.items():
            # Names are compared by value: ``is`` tests object identity and
            # only matched when both strings happened to be interned.
            if name == "classfier":
                x = x.view(x.size(0), -1)
            if name == "base":
                for block_name, cnn_block in module._modules.items():
                    x = cnn_block(x)
                    if block_name in self.extracted_layers:
                        outputs.append(x)
        return outputs