##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import torch
import torch.nn as nn
from torch.autograd import Variable

import encoding
import torchvision.models as resnet

# Supported backbone names -> torchvision constructors.
_BACKBONES = {
    'resnet50': resnet.resnet50,
    'resnet101': resnet.resnet101,
    'resnet152': resnet.resnet152,
}


def _first_variable(x):
    """Return the Variable reached by repeatedly taking element 0 of ``x``.

    Raises:
        RuntimeError: if the chain hits an empty container or a value that
            is neither a Variable nor a tuple/list.
    """
    node = x
    while not isinstance(node, Variable):
        if not isinstance(node, (tuple, list)) or len(node) == 0:
            raise RuntimeError('unknown input type: {}'.format(type(x)))
        node = node[0]
    return node


class Net(nn.Module):
    """ResNet backbone followed by an Encoding-layer classification head.

    Args:
        nclass: number of output classes (size of the final Linear layer).
        aux: stored as ``self.aux``; not used anywhere else in this file.
        backbone: one of 'resnet50', 'resnet101', 'resnet152'; the matching
            torchvision model is constructed with ``pretrained=True``.
        ncodes: number of codewords ``K`` in the Encoding layer. Defaults to
            32, the value that was previously hard-coded.

    Raises:
        RuntimeError: if ``backbone`` is not one of the supported names.
    """

    def __init__(self, nclass=23, aux=False, backbone='resnet50', ncodes=32):
        super(Net, self).__init__()
        if backbone not in _BACKBONES:
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.backbone = backbone
        # copying modules from pretrained models
        self.pretrained = _BACKBONES[backbone](pretrained=True)
        self.aux = aux
        self.head = nn.Sequential(
            # 1x1 conv reduces the 2048-channel layer4 output to 128 channels,
            # which is the descriptor size D given to the Encoding layer.
            nn.Conv2d(2048, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            encoding.nn.Encoding(D=128, K=ncodes),
            # Flatten the Encoding output to one 128*K vector per sample
            # (presumably a K x D residual encoding -- confirm in `encoding`).
            encoding.nn.View(-1, 128 * ncodes),
            encoding.nn.Normalize(),
            nn.Linear(128 * ncodes, nclass),
        )

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: a Variable, or a (possibly nested) tuple/list whose chain of
                first elements leads to one. That tensor must be 4-D; the
                axes are presumably (N, C, H, W) -- confirm with the loader.

        Returns:
            The output of ``self.head``: one ``nclass``-sized row per sample.

        Raises:
            RuntimeError: if ``x`` is not a Variable or a tuple/list that
                leads to one.
            ValueError: if the input tensor is not 4-D.
        """
        if isinstance(x, (tuple, list)):
            # Feed the unwrapped tensor itself. The old code measured this
            # tensor but then handed the container to conv1, which fails.
            x = _first_variable(x)
        elif not isinstance(x, Variable):
            raise RuntimeError('unknown input type: {}'.format(type(x)))
        if x.dim() != 4:
            raise ValueError(
                'expected a 4-D input, got {}-D'.format(x.dim()))

        if self.backbone in _BACKBONES:
            # Run the ResNet stages by hand, stopping after layer4 (its
            # avgpool/fc are skipped) so the head sees the spatial map.
            net = self.pretrained
            x = net.conv1(x)
            x = net.bn1(x)
            x = net.relu(x)
            x = net.maxpool(x)
            x = net.layer1(x)
            x = net.layer2(x)
            x = net.layer3(x)
            x = net.layer4(x)
        else:
            # Only reachable if `backbone` was changed after construction.
            x = self.pretrained(x)
        return self.head(x)


def test():
    """Smoke test: build Net on CUDA, run one random 1x3x224x224 batch,
    and print the model, its output and its total parameter count."""
    net = Net(nclass=23).cuda()
    print(net)
    x = Variable(torch.randn(1, 3, 224, 224)).cuda()
    y = net(x)
    print(y)
    total = sum(param.nelement() for param in net.parameters())
    print('Total params:', total)


if __name__ == "__main__":
    test()