"vscode:/vscode.git/clone" did not exist on "bf2a70872ee4b5da52e9b3a497005c8372484273"
## File: deepten.py
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree 
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.autograd import Variable

import encoding
import torchvision.models as resnet

class Net(nn.Module):
    """Classifier built from a pretrained ResNet trunk and an encoding head.

    The trunk is one of torchvision's ``resnet50`` / ``resnet101`` /
    ``resnet152`` constructed with ``pretrained=True``.  ``forward`` only
    runs the trunk's stem and ``layer1`` .. ``layer4``; the result is fed to
    ``self.head``, which ends in a linear layer with ``nclass`` outputs.
    """

    def __init__(self, args):
        """Build the trunk and the head.

        Args:
            args: object exposing ``nclass`` (output size of the final
                linear layer) and ``backbone`` (``'resnet50'``,
                ``'resnet101'`` or ``'resnet152'``).

        Raises:
            RuntimeError: if ``args.backbone`` is not one of the names above.
        """
        nclass = args.nclass
        super(Net, self).__init__()
        self.backbone = args.backbone
        # copying modules from pretrained models
        if self.backbone == 'resnet50':
            self.pretrained = resnet.resnet50(pretrained=True)
        elif self.backbone == 'resnet101':
            self.pretrained = resnet.resnet101(pretrained=True)
        elif self.backbone == 'resnet152':
            self.pretrained = resnet.resnet152(pretrained=True)
        else:
            raise RuntimeError('unknown backbone: {}'.format(self.backbone))
        n_codes = 32
        self.head = nn.Sequential(
            # 1x1 convolution taking 2048 channels down to 128
            # (2048 is presumably the layer4 width of these ResNets -- confirm).
            nn.Conv2d(2048, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            encoding.nn.Encoding(D=128, K=n_codes),
            # Flatten to one row of 128 * n_codes values before the classifier.
            encoding.nn.View(-1, 128 * n_codes),
            encoding.nn.Normalize(),
            nn.Linear(128 * n_codes, nclass),
        )

    def forward(self, x):
        """Run the trunk stages and the head on ``x``.

        Args:
            x: a ``Variable``, or a (possibly nested) tuple/list whose first
                leaf is a ``Variable``.  The ``Variable`` that is inspected
                must be 4-dimensional.

        Returns:
            The output of ``self.head`` applied to the ``layer4`` features.

        Raises:
            RuntimeError: if ``x`` is neither a ``Variable`` nor a tuple/list.
            ValueError: if the inspected ``Variable`` is not 4-dimensional.
        """
        probe = x
        if isinstance(x, (tuple, list)):
            # Descend through first elements until a Variable is reached.
            while not isinstance(probe, Variable):
                probe = probe[0]
        elif not isinstance(x, Variable):
            raise RuntimeError('unknown input type: {}'.format(type(x)))
        # The sizes themselves are not used below; the four-way unpack is kept
        # because it rejects inputs that are not 4-D.
        _, _, _, _ = probe.size()

        x = self.pretrained.conv1(x)
        x = self.pretrained.bn1(x)
        x = self.pretrained.relu(x)
        x = self.pretrained.maxpool(x)
        x = self.pretrained.layer1(x)
        x = self.pretrained.layer2(x)
        x = self.pretrained.layer3(x)
        x = self.pretrained.layer4(x)
        return self.head(x)