customize.py 6.42 KB
Newer Older
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
1
2
3
4
5
6
7
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
Hang Zhang's avatar
sync BN  
Hang Zhang committed
8
## LICENSE file in the root directory of this source tree
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
9
10
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

Hang Zhang's avatar
sync BN  
Hang Zhang committed
11
"""Encoding Custermized NN Module"""
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
12
import torch
Hang Zhang's avatar
Hang Zhang committed
13
import torch.nn as nn
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
14
from torch.nn import functional as F
Zhang's avatar
v0.4.2  
Zhang committed
15
from torch.autograd import Variable
16
17
from .splat import SplAtConv2d
from .rectify import RFConv2d
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
18

Zhang's avatar
v0.4.2  
Zhang committed
19
20
# "major.minor" of the installed torch, e.g. "1.10" for "1.10.0+cu113".
# Slicing the first three characters would wrongly give "1.1" there.
torch_ver = '.'.join(torch.__version__.split('.')[:2])

__all__ = ['ConvBnAct', 'GlobalAvgPool2d', 'GramMatrix',
           'View', 'Sum', 'Mean', 'Normalize', 'ConcurrentModule',
           'PyramidPooling']
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
24

25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class ConvBnAct(nn.Sequential):
    """Convolution -> normalization -> optional ReLU, as one ``nn.Sequential``.

    The convolution layer is chosen from the arguments:

    * ``radix > 0``: ``SplAtConv2d``, additionally given ``radix``,
      ``rectify``, ``rectify_avg`` and ``norm_layer``;
    * otherwise, if ``rectify``: ``RFConv2d`` with ``average_mode=rectify_avg``;
    * otherwise: a plain ``nn.Conv2d``.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        kernel_size, stride, padding, dilation, groups, bias, padding_mode:
            forwarded unchanged to the convolution layer.
        radix (int): radix for ``SplAtConv2d``; ``0`` selects the other paths.
        rectify (bool): select the rectified convolution variant.
        rectify_avg (bool): averaging flag forwarded to the rectified conv.
        act (bool): append an ``nn.ReLU`` after the normalization layer.
        norm_layer (callable): factory called as ``norm_layer(out_channels)``
            to build the normalization layer. Default: ``nn.BatchNorm2d``.

    Submodules are registered as ``conv``, ``bn`` and (when ``act``) ``relu``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, radix=0, groups=1,
                 bias=True, padding_mode='zeros',
                 rectify=False, rectify_avg=False, act=True,
                 norm_layer=nn.BatchNorm2d):
        super().__init__()
        if radix > 0:
            conv_layer = SplAtConv2d
            conv_kwargs = {'radix': radix, 'rectify': rectify, 'rectify_avg': rectify_avg, 'norm_layer': norm_layer}
        else:
            conv_layer = RFConv2d if rectify else nn.Conv2d
            conv_kwargs = {'average_mode': rectify_avg} if rectify else {}
        self.add_module("conv", conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                           padding=padding, dilation=dilation, groups=groups, bias=bias,
                                           padding_mode=padding_mode, **conv_kwargs))
        # Build the norm layer from the caller's factory. It used to be
        # hard-coded to nn.BatchNorm2d, which silently ignored norm_layer.
        # The submodule name "bn" is kept so state_dict keys do not change.
        self.add_module("bn", norm_layer(out_channels))
        if act:
            self.add_module("relu", nn.ReLU())


Hang Zhang's avatar
Hang Zhang committed
46
47
48
49
50
51
52
53
class GlobalAvgPool2d(nn.Module):
    """Average each channel over all spatial positions.

    For a 4-D ``(N, C, H, W)`` input, pools the spatial dimensions down to a
    single value and returns the result flattened to shape ``(N, C)``.
    """
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        pooled = F.adaptive_avg_pool2d(inputs, 1)
        batch_size = inputs.size(0)
        return pooled.view(batch_size, -1)

54

Hang Zhang's avatar
Hang Zhang committed
55
class GramMatrix(nn.Module):
    r"""Normalized Gram matrix of a mini-batch of 4D convolutional feature maps.

    For each sample with feature map :math:`\mathcal{F}` of size
    :math:`C \times H \times W`, where :math:`\mathcal{F}_{h,w}` is the
    length-:math:`C` channel vector at spatial position :math:`(h, w)`:

    .. math::
        \mathcal{G} = \frac{1}{C H W} \sum_{h=1}^{H}\sum_{w=1}^{W} \mathcal{F}_{h,w}\mathcal{F}_{h,w}^T

    Input shape ``(B, C, H, W)``; output shape ``(B, C, C)``.
    """
    def forward(self, y):
        (b, ch, h, w) = y.size()
        # reshape (not view) so non-contiguous inputs, e.g. spatially
        # transposed/permuted tensors, are accepted too.
        features = y.reshape(b, ch, h * w)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (ch * h * w)
        return gram
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
67

68

Hang Zhang's avatar
Hang Zhang committed
69
class View(nn.Module):
    """Reshape the input to a fixed target size with ``Tensor.view``.

    The output shares storage with the input (no data is copied), so the input
    must be view-compatible with the target size; ``Tensor.view`` raises
    otherwise.

    Args:
        *args: the target size, given either as separate ints
            (``View(2, 3)``) or as a single ``torch.Size``, tuple or list
            (``View((2, 3))``).
    """
    def __init__(self, *args):
        super(View, self).__init__()
        # A single sequence argument is the whole size. Previously only a
        # torch.Size was unpacked, so View((2, 3)) / View([2, 3]) raised.
        if len(args) == 1 and isinstance(args[0], (tuple, list)):
            self.size = torch.Size(args[0])
        else:
            self.size = torch.Size(args)

    def forward(self, input):
        return input.view(self.size)
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
82
83


Hang Zhang's avatar
Hang Zhang committed
84
class Sum(nn.Module):
    """Module form of ``torch.sum`` over a fixed dimension.

    Args:
        dim: dimension(s) to reduce, passed straight to ``torch.sum``.
        keep_dim (bool): retain reduced dimension(s) with size 1.
            Default: ``False``.
    """
    def __init__(self, dim, keep_dim=False):
        super(Sum, self).__init__()
        self.keep_dim = keep_dim
        self.dim = dim

    def forward(self, input):
        return torch.sum(input, dim=self.dim, keepdim=self.keep_dim)
Hang Zhang's avatar
v0.1.0  
Hang Zhang committed
92
93


Hang Zhang's avatar
Hang Zhang committed
94
class Mean(nn.Module):
    """Module form of ``torch.mean`` over a fixed dimension.

    Args:
        dim: dimension(s) to average over, passed straight to ``torch.mean``.
        keep_dim (bool): retain reduced dimension(s) with size 1.
            Default: ``False``.
    """
    def __init__(self, dim, keep_dim=False):
        super(Mean, self).__init__()
        self.keep_dim = keep_dim
        self.dim = dim

    def forward(self, input):
        return torch.mean(input, dim=self.dim, keepdim=self.keep_dim)
Hang Zhang's avatar
v0.1.0  
Hang Zhang committed
102
103


Hang Zhang's avatar
Hang Zhang committed
104
class Normalize(nn.Module):
    r"""Performs :math:`L_p` normalization of inputs over specified dimension.

    Does:

    .. math::
        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}

    for each subtensor v over dimension dim of input. Each subtensor is
    flattened into a vector, i.e. :math:`\lVert v \rVert_p` is not a matrix
    norm.

    With default arguments normalizes over the second dimension with Euclidean
    norm.

    Args:
        p (float): the exponent value in the norm formulation. Default: 2
        dim (int): the dimension to reduce. Default: 1
        eps (float): the :math:`\epsilon` lower bound on the norm, which avoids
            division by zero. Default: 1e-8
    """
    def __init__(self, p=2, dim=1, eps=1e-8):
        super(Normalize, self).__init__()
        self.p = p
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        # getattr fallback keeps instances pickled before `eps` existed working.
        eps = getattr(self, 'eps', 1e-8)
        return F.normalize(x, self.p, self.dim, eps=eps)

Hang Zhang's avatar
Hang Zhang committed
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
class ConcurrentModule(nn.ModuleList):
    r"""Apply every contained module to the same input and join the outputs.

    Each child receives the identical input ``x``. The children's outputs are
    concatenated, in list order, along dimension 1 (the channel dimension).

    Args:
        modules (iterable, optional): an iterable of modules to add
    """
    def __init__(self, modules=None):
        super().__init__(modules)

    def forward(self, x):
        return torch.cat([branch(x) for branch in self], 1)
Hang Zhang's avatar
Hang Zhang committed
146

Hang Zhang's avatar
Hang Zhang committed
147
class PyramidPooling(nn.Module):
    """Pyramid pooling module.

    Each branch works as follows:

    * average-pool the input to a 1x1, 2x2, 3x3 or 6x6 grid;
    * project it to ``in_channels // 4`` channels with a
      1x1 conv -> ``norm_layer`` -> ReLU;
    * upsample it back to the input's spatial size with ``F.interpolate``.

    The four branch outputs are concatenated after the unchanged input along
    dimension 1. The output therefore has ``in_channels + 4 * (in_channels // 4)``
    channels.

    Args:
        in_channels (int): number of channels of the input feature map.
        norm_layer (callable): factory called as ``norm_layer(channels)``.
        up_kwargs (dict): keyword arguments forwarded to ``F.interpolate``
            (for example ``mode`` / ``align_corners``).

    Reference:
        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
    """
    def __init__(self, in_channels, norm_layer, up_kwargs):
        super(PyramidPooling, self).__init__()
        self.pool1 = nn.AdaptiveAvgPool2d(1)
        self.pool2 = nn.AdaptiveAvgPool2d(2)
        self.pool3 = nn.AdaptiveAvgPool2d(3)
        self.pool4 = nn.AdaptiveAvgPool2d(6)

        # Integer division; the previous int(in_channels/4) detoured via float.
        out_channels = in_channels // 4
        # Attribute names and the Sequential layout of each branch are kept,
        # so existing state_dict keys (e.g. "conv1.0.weight") still load.
        self.conv1 = self._make_branch(in_channels, out_channels, norm_layer)
        self.conv2 = self._make_branch(in_channels, out_channels, norm_layer)
        self.conv3 = self._make_branch(in_channels, out_channels, norm_layer)
        self.conv4 = self._make_branch(in_channels, out_channels, norm_layer)
        # bilinear interpolate options
        self._up_kwargs = up_kwargs

    @staticmethod
    def _make_branch(in_channels, out_channels, norm_layer):
        """Build one 1x1 conv -> norm -> ReLU projection branch."""
        return nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
                             norm_layer(out_channels),
                             nn.ReLU(True))

    def forward(self, x):
        _, _, h, w = x.size()
        branches = ((self.pool1, self.conv1), (self.pool2, self.conv2),
                    (self.pool3, self.conv3), (self.pool4, self.conv4))
        feats = [x] + [F.interpolate(conv(pool(x)), (h, w), **self._up_kwargs)
                       for pool, conv in branches]
        return torch.cat(feats, 1)