deepten.py 2.25 KB
Newer Older
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree 
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.autograd import Variable

import encoding
Hang Zhang's avatar
Hang Zhang committed
17
import encoding.models.resnet as resnet
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
18
19

class Net(nn.Module):
    """Deep TEN classifier: a ResNet backbone followed by an encoding head.

    The backbone is built from ``encoding.models.resnet`` with
    ``pretrained=True`` and ``dilated=False``. The head reduces the final
    feature map with a 1x1 convolution, aggregates it with
    ``encoding.nn.Encoding``, flattens and normalizes the result, and maps
    it to class scores with a linear layer.

    Args:
        args: configuration object. This constructor reads ``args.nclass``
            (number of output classes) and ``args.backbone`` (one of
            ``'resnet50'``, ``'resnet101'``, ``'resnet152'``).

    Raises:
        RuntimeError: if ``args.backbone`` is not a supported name.
    """

    def __init__(self, args):
        nclass = args.nclass
        super(Net, self).__init__()
        self.backbone = args.backbone
        # copying modules from pretrained models
        if self.backbone == 'resnet50':
            self.pretrained = resnet.resnet50(pretrained=True, dilated=False)
        elif self.backbone == 'resnet101':
            self.pretrained = resnet.resnet101(pretrained=True, dilated=False)
        elif self.backbone == 'resnet152':
            self.pretrained = resnet.resnet152(pretrained=True, dilated=False)
        else:
            raise RuntimeError('unknown backbone: {}'.format(self.backbone))
        n_codes = 32
        self.head = nn.Sequential(
            # 1x1 conv from 2048 channels (what layer4 of the backbones above
            # is expected to emit) down to 128 channels.
            nn.Conv2d(2048, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            encoding.nn.Encoding(D=128, K=n_codes),
            # reshape to (-1, 128 * n_codes) for the linear classifier
            encoding.nn.View(-1, 128 * n_codes),
            encoding.nn.Normalize(),
            nn.Linear(128 * n_codes, nclass),
        )

    @staticmethod
    def _unwrap_input(x):
        """Return the tensor the backbone should consume.

        ``x`` may be a tensor/Variable, or a (possibly nested) tuple/list
        whose first leaf is one. In the latter case, that first leaf is
        returned.

        Raises:
            RuntimeError: if ``x`` is neither, or a nested sequence is empty.
        """
        var_input = x
        # Bounded descent: stop as soon as we leave tuple/list territory so
        # that leaves such as strings (whose [0] is again a string) cannot
        # loop forever.
        while isinstance(var_input, (tuple, list)):
            if not var_input:
                raise RuntimeError('empty input sequence: {}'.format(type(x)))
            var_input = var_input[0]
        if not isinstance(var_input, Variable):
            raise RuntimeError('unknown input type: {}'.format(type(x)))
        return var_input

    def forward(self, x):
        """Run the backbone stages and the encoding head on ``x``.

        Args:
            x: input tensor/Variable, or a (possibly nested) tuple/list whose
                first leaf is the input tensor. NOTE(review): presumably an
                image batch of shape (N, C, H, W) -- confirm with callers.

        Returns:
            The output of ``self.head`` (scores from its final linear layer).

        Raises:
            RuntimeError: if ``x`` is not a tensor or a sequence wrapping one.
        """
        x = self._unwrap_input(x)
        x = self.pretrained.conv1(x)
        x = self.pretrained.bn1(x)
        x = self.pretrained.relu(x)
        x = self.pretrained.maxpool(x)
        x = self.pretrained.layer1(x)
        x = self.pretrained.layer2(x)
        x = self.pretrained.layer3(x)
        x = self.pretrained.layer4(x)
        return self.head(x)