##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Encoding Package Core NN Modules."""
import math
import torch
from torch.nn import Module, Parameter
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules.utils import _pair

from ..functions import scaledL2, aggregate, pairwise_cosine

__all__ = ['Encoding', 'EncodingDrop', 'Inspiration', 'UpsampleConv2d']

class Encoding(Module):
    r"""
    Encoding Layer: a learnable residual encoder.

    .. image:: _static/img/cvpr17.svg
        :width: 50%
        :align: center

    Encoding Layer accepts 3D or 4D inputs.
    It considers an input feature map with the shape of :math:`C\times H\times W`
    as a set of C-dimensional input features :math:`X=\{x_1, ...x_N\}`, where N is
    the total number of features given by :math:`H\times W`. The layer learns an
    inherent codebook :math:`D=\{d_1,...d_K\}` and a set of smoothing factors for
    the visual centers :math:`S=\{s_1,...s_K\}`. Encoding Layer outputs the
    aggregated residuals with soft-assignment weights :math:`e_k=\sum_{i=1}^N e_{ik}`, where

    .. math::

        e_{ik} = \frac{\exp(-s_k\|r_{ik}\|^2)}{\sum_{j=1}^K \exp(-s_j\|r_{ij}\|^2)} r_{ik}

    and the residuals are given by :math:`r_{ik} = x_i - d_k`. The output encoders are
    :math:`E=\{e_1,...e_K\}`.

    Args:
        D: dimension of the features or feature channels
        K: number of codewords

    Shape:
        - Input: :math:`X\in\mathcal{R}^{B\times N\times D}` or
          :math:`\mathcal{R}^{B\times D\times H\times W}` (where :math:`B` is batch,
          :math:`N` is the total number of features, i.e. :math:`H\times W`.)
        - Output: :math:`E\in\mathcal{R}^{B\times K\times D}`

    Attributes:
        codewords (Tensor): the learnable codewords of shape (:math:`K\times D`)
        scale (Tensor): the learnable smoothing factors of the visual centers, of shape (:math:`K`)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

        Hang Zhang, Jia Xue, and Kristin Dana. "Deep TEN: Texture Encoding Network."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2017*

    Examples:
        >>> import encoding
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from torch.autograd import Variable
        >>> B,C,H,W,K = 2,3,4,5,6
        >>> X = Variable(torch.cuda.DoubleTensor(B,C,H,W).uniform_(-0.5,0.5), requires_grad=True)
        >>> layer = encoding.Encoding(C,K).double().cuda()
        >>> E = layer(X)
    """
    def __init__(self, D, K):
        super(Encoding, self).__init__()
        # init codewords and smoothing factor
        self.D, self.K = D, K
        self.codewords = Parameter(torch.Tensor(K, D), requires_grad=True)
        self.scale = Parameter(torch.Tensor(K), requires_grad=True)
        self.reset_params()

    def reset_params(self):
        std1 = 1. / math.sqrt(self.K * self.D)
        self.codewords.data.uniform_(-std1, std1)
        # kept negative so softmax(scale * ||r||^2) matches exp(-s_k ||r||^2)
        self.scale.data.uniform_(-1, 0)

    def forward(self, X):
        # input X is a 3D (BxDxN) or 4D (BxDxHxW) tensor
        assert X.size(1) == self.D
        if X.dim() == 3:
            # BxDxN
            B, D = X.size(0), self.D
            X = X.transpose(1, 2).contiguous()
        elif X.dim() == 4:
            # BxDxHxW
            B, D = X.size(0), self.D
            X = X.view(B, D, -1).transpose(1, 2).contiguous()
        else:
            raise RuntimeError('Encoding Layer: unknown input dims!')
        # assignment weights BxNxK (softmax over the K codewords)
        A = F.softmax(scaledL2(X, self.codewords, self.scale), dim=2)
        # aggregate
        E = aggregate(A, X, self.codewords)
        return E

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'N x ' + str(self.D) + '=>' + str(self.K) + 'x' \
            + str(self.D) + ')'
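

# A minimal pure-PyTorch sketch of the Encoding forward pass, for clarity
# only; the real path goes through the package's `scaledL2`/`aggregate`
# CUDA-backed functions. It assumes X is already flattened to BxNxD, and
# that the learned `scale` is negative (it is initialized in [-1, 0)), so
# softmax(scale * ||r||^2) realizes the exp(-s_k\|r_{ik}\|^2) weights above.
def _encoding_reference(X, codewords, scale):
    # X: BxNxD, codewords: KxD, scale: K
    R = X.unsqueeze(2) - codewords.unsqueeze(0).unsqueeze(0)  # BxNxKxD residuals
    SL = scale.view(1, 1, -1) * R.pow(2).sum(3)               # BxNxK scaled distances
    A = F.softmax(SL, dim=2)                                  # soft assignments over K
    return (A.unsqueeze(3) * R).sum(1)                        # BxKxD encoders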


class EncodingDrop(Module):
    r"""Dropout regularized Encoding Layer.
    """
    def __init__(self, D, K):
        super(EncodingDrop, self).__init__()
        # init codewords and smoothing factor
        self.D, self.K = D, K
        self.codewords = Parameter(torch.Tensor(K, D), requires_grad=True)
        self.scale = Parameter(torch.Tensor(K), requires_grad=True)
        self.reset_params()

    def reset_params(self):
        std1 = 1. / math.sqrt(self.K * self.D)
        self.codewords.data.uniform_(-std1, std1)
        self.scale.data.uniform_(-1, 0)

    def _drop(self):
        if self.training:
            # re-sample the smoothing factors on every training forward pass
            self.scale.data.uniform_(-1, 0)
        else:
            # pin the smoothing factors to their expected value at eval time
            self.scale.data.zero_().add_(-0.5)

    def forward(self, X):
        # input X is a 3D (BxDxN) or 4D (BxDxHxW) tensor
        assert X.size(1) == self.D
        if X.dim() == 3:
            # BxDxN
            B, D = X.size(0), self.D
            X = X.transpose(1, 2).contiguous()
        elif X.dim() == 4:
            # BxDxHxW
            B, D = X.size(0), self.D
            X = X.view(B, D, -1).transpose(1, 2).contiguous()
        else:
            raise RuntimeError('Encoding Layer: unknown input dims!')
        self._drop()
        # assignment weights BxNxK (softmax over the K codewords)
        A = F.softmax(scaledL2(X, self.codewords, self.scale), dim=2)
154
        # aggregate
        E = aggregate(A, X, self.codewords)
        self._drop()
        return E

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'N x ' + str(self.D) + '=>' + str(self.K) + 'x' \
            + str(self.D) + ')'
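

# Usage sketch for EncodingDrop (illustrative, not part of the original file):
# `_drop` re-samples the smoothing factors uniformly in [-1, 0) on every
# training-mode forward pass and pins them to -0.5 in eval mode, loosely
# analogous to Dropout's different train/test behavior.
#
#   layer = EncodingDrop(D=16, K=8)
#   layer.train()                         # scale re-randomized inside forward
#   E = layer(torch.randn(2, 16, 8, 8))
#   layer.eval()                          # scale fixed to -0.5
#   E = layer(torch.randn(2, 16, 8, 8))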


class Inspiration(Module):
    r"""
    Inspiration Layer (CoMatch Layer) enables the multi-style transfer in feed-forward
    network, which learns to match the target feature statistics during the training.
    This module is differentialble and can be inserted in standard feed-forward network
    to be learned directly from the loss function without additional supervision.

    .. math::
        Y = \phi^{-1}[\phi(\mathcal{F}^T)W\mathcal{G}]

    Please see the `example of MSG-Net <./experiments/style.html>`_ for
    training a multi-style generative network for real-time transfer.

    Reference:
        Hang Zhang and Kristin Dana. "Multi-style Generative Network for Real-time Transfer."
        *arXiv preprint arXiv:1703.06953 (2017)*
    """
    def __init__(self, C, B=1):
        super(Inspiration, self).__init__()
        # B is equal to 1 or input mini_batch
        self.weight = Parameter(torch.Tensor(1, C, C), requires_grad=True)
        # non-parameter buffer
        self.G = Variable(torch.Tensor(B, C, C), requires_grad=True)
        self.C = C
        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.uniform_(0.0, 0.02)

    def setTarget(self, target):
        self.G = target

    def forward(self, X):
        # input X is a 4D feature map of shape (B, C, H, W)
        # P: (B, C, C), a learned linear transform of the target Gram matrix G
        self.P = torch.bmm(self.weight.expand_as(self.G), self.G)
        return torch.bmm(self.P.transpose(1, 2).expand(X.size(0), self.C, self.C),
                         X.view(X.size(0), X.size(1), -1)).view_as(X)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'N x ' + str(self.C) + ')'
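

# A usage sketch for Inspiration (illustrative; `_gram_matrix` is a
# hypothetical helper, not part of this file). The layer is driven by
# setting a target Gram matrix computed from style features:
#
#   inspiration = Inspiration(C=512)
#   inspiration.setTarget(_gram_matrix(style_features))
#   y = inspiration(content_features)
def _gram_matrix(feat):
    # feat: BxCxHxW -> BxCxC Gram matrix of channel-wise correlations
    B, C, H, W = feat.size()
    f = feat.view(B, C, H * W)
    return f.bmm(f.transpose(1, 2)) / (C * H * W)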


class UpsampleConv2d(Module):
    r"""
    To avoid the checkerboard artifacts of standard fractionally-strided convolution,
    we adopt an integer-stride convolution but produce a :math:`2\times 2` block of
    outputs for each convolutional window.

    .. image:: _static/img/upconv.png
        :width: 50%
        :align: center

    Reference:
        Hang Zhang and Kristin Dana. "Multi-style Generative Network for Real-time Transfer."
        *arXiv preprint arXiv:1703.06953 (2017)*

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output
          channels. Default: 1
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        scale_factor (int): scaling factor for upsampling convolution. Default: 1

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
          :math:`H_{out} = scale * \lfloor (H_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1 \rfloor`
          :math:`W_{out} = scale * \lfloor (W_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1 \rfloor`

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         (scale * scale * out_channels, in_channels // groups, kernel_size[0], kernel_size[1])
        bias (Tensor):   the learnable bias of the module of shape (scale * scale * out_channels)

    Examples:
        >>> # With square kernels and equal stride
        >>> m = encoding.nn.UpsampleConv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = encoding.nn.UpsampleConv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = autograd.Variable(torch.randn(20, 16, 50, 100))
        >>> output = m(input)
        >>> # upsampling by a factor of 2 via scale_factor
        >>> input = autograd.Variable(torch.randn(1, 16, 12, 12))
        >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = encoding.nn.UpsampleConv2d(16, 16, 3, stride=1, padding=1,
        ...                                       scale_factor=2)
        >>> h = downsample(input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> output = upsample(h)
        >>> output.size()
        torch.Size([1, 16, 12, 12])

    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, scale_factor=1,
                 bias=True):
        super(UpsampleConv2d, self).__init__()
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.scale_factor = scale_factor
        self.weight = Parameter(torch.Tensor(
            out_channels * scale_factor * scale_factor,
            in_channels // groups, *kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(
                out_channels * scale_factor * scale_factor))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        # convolve to produce scale^2 * out_channels maps, then rearrange
        # them into the upsampled output via pixel shuffle
        out = F.conv2d(input, self.weight, self.bias, self.stride,
                       self.padding, self.dilation, self.groups)
        return F.pixel_shuffle(out, self.scale_factor)
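

# A quick shape sanity-check for UpsampleConv2d (a sketch, assuming a 3x3
# convolution with stride=1 and padding=1, which preserves spatial size, so
# the output grows exactly by `scale_factor` through pixel_shuffle):
#
#   up = UpsampleConv2d(16, 16, 3, stride=1, padding=1, scale_factor=2)
#   x = torch.randn(1, 16, 12, 12)
#   assert up(x).size() == (1, 16, 24, 24)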