from __future__ import absolute_import

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

from aligned.HorizontalMaxPool2D import HorizontalMaxPool2d

__all__ = ['DenseNet121']


class DenseNet121(nn.Module):
    """DenseNet-121 backbone for person re-identification.

    Produces a 1024-d global feature and, when ``aligned`` is set, stripe-wise
    local features for the AlignedReID local distance.
    """

    def __init__(self, num_classes, loss={'softmax'}, aligned=False, **kwargs):
        super(DenseNet121, self).__init__()
        self.loss = loss
        densenet121 = torchvision.models.densenet121(pretrained=True)
        self.base = densenet121.features  # convolutional trunk only, classifier dropped
        self.classifier = nn.Linear(1024, num_classes)
        self.feat_dim = 1024  # feature dimension
        self.aligned = aligned
        # Max-pools each feature map across its full width:
        # (N, C, H, W) -> (N, C, H, 1), one response per horizontal stripe.
        self.horizon_pool = HorizontalMaxPool2d()
        if self.aligned:
            # Extra head that maps 1024-channel stripes to 128-d local features.
            self.bn = nn.BatchNorm2d(1024)
            self.relu = nn.ReLU(inplace=True)
            self.conv1 = nn.Conv2d(1024, 128, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        x = self.base(x)  # (N, 1024, H, W) feature map
        if not self.training:
            # At test time the local feature is the horizontally pooled map itself.
            lf = self.horizon_pool(x)
        if self.aligned and self.training:
            # During training, local features pass through BN -> ReLU ->
            # horizontal pooling -> 1x1 conv, giving 128 channels per stripe.
            lf = self.bn(x)
            lf = self.relu(lf)
            lf = self.horizon_pool(lf)
            lf = self.conv1(lf)
        if self.aligned or not self.training:
            # Drop the singleton width dimension, then L2-normalize each
            # stripe's channel vector: (N, C, H, 1) -> (N, C, H).
            lf = lf.view(lf.size()[0:3])
            lf = lf / torch.pow(lf, 2).sum(dim=1, keepdim=True).clamp(min=1e-12).sqrt()
        # Global feature: spatial average pooling followed by flattening.
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        # f = 1. * f / (torch.norm(f, 2, dim=-1, keepdim=True).expand_as(f) + 1e-12)
        if not self.training:
            return f, lf
        y = self.classifier(f)
        if self.loss == {'softmax'}:
            return y
        elif self.loss == {'metric'}:
            if self.aligned:
                return f, lf
            return f
        elif self.loss == {'softmax', 'metric'}:
            if self.aligned:
                return y, f, lf
            return y, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
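

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch, not part of the original module. It assumes the
# repo's ``aligned`` package is importable, and the 256x128 input size, batch
# size of 4, and class count of 751 (Market-1501) are illustrative choices,
# not values mandated by this file. Note that constructing the model downloads
# ImageNet weights on first run.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = DenseNet121(num_classes=751, loss={'softmax', 'metric'}, aligned=True)
    x = torch.randn(4, 3, 256, 128)  # (batch, channels, height, width)

    # Training mode: logits, global feature, and 128-channel local stripes.
    model.train()
    y, f, lf = model(x)
    # Expected (assuming the backbone's output stride of 32, so an 8x4 map):
    # [4, 751], [4, 1024], [4, 128, 8]
    print(y.size(), f.size(), lf.size())

    # Eval mode: global feature plus horizontally pooled local feature.
    model.eval()
    with torch.no_grad():
        f, lf = model(x)
    # Expected: [4, 1024], [4, 1024, 8]
    print(f.size(), lf.size())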