"vscode:/vscode.git/clone" did not exist on "b34acbdcbc6b04473825fee716666ee26a0f87b2"
__init__.py 8.78 KB
Newer Older
Hang Zhang's avatar
init  
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree 
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

Hang Zhang's avatar
Hang Zhang committed
11
import threading
Hang Zhang's avatar
init  
Hang Zhang committed
12
import torch
Hang Zhang's avatar
Hang Zhang committed
13
import torch.cuda.nccl as nccl
Hang Zhang's avatar
Hang Zhang committed
14
import torch.nn as nn
Hang Zhang's avatar
Hang Zhang committed
15
16
from torch.autograd import Function, Variable
from torch.nn.parameter import Parameter
Hang Zhang's avatar
init  
Hang Zhang committed
17
18
19
20
21
from ._ext import encoding_lib

class aggregate(Function):
    """Legacy autograd Function: aggregate residuals ``R`` weighted by ``A``.

    Shapes (from the original annotation and the allocations below):
        A: (B, N, K)     -- assignment weights
        R: (B, N, K, D)  -- residuals
        E: (B, K, D)     -- output of ``forward``

    Both directions run compiled ``encoding_lib`` kernels and only accept
    ``torch.cuda.FloatTensor`` / ``torch.cuda.DoubleTensor``; any other
    tensor type raises ``RuntimeError('unimplemented')``.
    """

    def forward(self, A, R):
        """Return the aggregated encoding ``E`` of shape (B, K, D)."""
        self.save_for_backward(A, R)
        B, N, K, D = R.size()
        # Allocate under device_of(A), as the other Functions in this file
        # do, so the output buffer is created in A's device context.
        with torch.cuda.device_of(A):
            E = A.new(B, K, D)
        # TODO support cpu backend
        if isinstance(A, torch.cuda.FloatTensor):
            with torch.cuda.device_of(A):
                encoding_lib.Encoding_Float_aggregate_forward(E, A, R)
        elif isinstance(A, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(A):
                encoding_lib.Encoding_Double_aggregate_forward(E, A, R)
        else:
            raise RuntimeError('unimplemented')
        return E

    def backward(self, gradE):
        """Return ``(gradA, gradR)`` shaped like the saved ``A`` and ``R``."""
        A, R = self.saved_tensors
        # NOTE(review): gradA/gradR are not zeroed; the kernel is assumed to
        # overwrite every element -- TODO confirm against the CUDA source.
        with torch.cuda.device_of(A):
            gradA = A.new().resize_as_(A)
            gradR = R.new().resize_as_(R)
        if isinstance(A, torch.cuda.FloatTensor):
            with torch.cuda.device_of(A):
                encoding_lib.Encoding_Float_aggregate_backward(
                    gradA, gradR, gradE, A, R)
        elif isinstance(A, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(A):
                encoding_lib.Encoding_Double_aggregate_backward(
                    gradA, gradR, gradE, A, R)
        else:
            raise RuntimeError('unimplemented')
        return gradA, gradR
Hang Zhang's avatar
init  
Hang Zhang committed
47
48


Hang Zhang's avatar
Hang Zhang committed
49
class Aggregate(nn.Module):
    """``nn.Module`` front-end for the ``aggregate`` autograd Function.

    Each call constructs a new ``aggregate`` object and applies it to
    ``(A, R)``, returning whatever that Function returns.
    """

    def forward(self, A, R):
        op = aggregate()
        return op(A, R)
Hang Zhang's avatar
Hang Zhang committed
52

Hang Zhang's avatar
Hang Zhang committed
53

Hang Zhang's avatar
Hang Zhang committed
54
55
56
57
58
59
60
61
62
63
64
class Encoding(nn.Module):
    """Encoding layer: soft-assigns features to ``K`` learnable codewords.

    Args:
        D (int): number of input channels (dimension of each codeword).
        K (int): number of codewords.

    ``forward`` takes ``X`` of shape (B, D, H, W) -- or (D, H, W), which is
    treated as a batch of one -- and returns ``E`` of shape (B, K, D)
    (or (K, D) for unbatched input).
    """

    def __init__(self, D, K):
        super(Encoding, self).__init__()
        # init codewords and smoothing factor
        self.D, self.K = D, K
        self.codewords = nn.Parameter(torch.Tensor(K, D), requires_grad=True)
        self.scale = nn.Parameter(torch.Tensor(K), requires_grad=True)
        # No explicit dim: relies on Softmax's implicit-dim rule, which for
        # the 2-D (B*N, K) view used in forward normalises over the K axis.
        self.softmax = nn.Softmax()
        self.reset_params()

    def reset_params(self):
        """Re-initialise the parameters from symmetric uniform distributions.

        Codewords are drawn from U(-1/sqrt(K*D), 1/sqrt(K*D)). Scales are
        drawn from U(-1/sqrt(K), 1/sqrt(K)).

        The exponent is written as ``0.5`` rather than ``1/2``. Under
        Python 2 integer division ``1/2 == 0``, which silently turned both
        bounds into 1.
        """
        std1 = 1. / ((self.K * self.D) ** 0.5)
        std2 = 1. / (self.K ** 0.5)
        self.codewords.data.uniform_(-std1, std1)
        self.scale.data.uniform_(-std2, std2)

    def forward(self, X):
        """Encode ``X`` into a (B, K, D) tensor ((K, D) for 3-D input).

        Raises:
            ValueError: if the channel dimension of ``X`` is not ``D``.
        """
        # Accept an unbatched (D, H, W) input by adding a batch axis.
        unpacked = False
        if X.dim() == 3:
            unpacked = True
            X = X.unsqueeze(0)
        # Validate after unsqueeze so dim 1 is the channel axis in both cases.
        # A parenthesised ``assert(cond, msg)`` asserts a non-empty tuple and
        # can never fail, so raise explicitly instead.
        if X.size(1) != self.D:
            raise ValueError("Encoding Layer incompatible input channels!")

        B, N, K, D = X.size(0), X.size(2) * X.size(3), self.K, self.D
        # reshape input: (B, D, H, W) -> (B, N, D)
        X = X.view(B, D, -1).transpose(1, 2)
        # calculate residuals: R[b, n, k] = X[b, n] - codewords[k]
        R = X.contiguous().view(B, N, 1, D).expand(B, N, K, D) \
            - self.codewords.view(1, 1, K, D).expand(B, N, K, D)
        # assignment weights: softmax over K of scale[k] * ||R[b, n, k]||^2
        A = R.pow(2).sum(3).view(B, N, K)
        A = A * self.scale.view(1, 1, K).expand_as(A)
        A = self.softmax(A.view(B * N, K)).view(B, N, K)
        # aggregate
        E = aggregate()(A, R)

        if unpacked:
            E = E.squeeze(0)
        return E

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'N x ' + str(self.D) + '=>' + str(self.K) + 'x' + str(self.D) + ')'
Hang Zhang's avatar
Hang Zhang committed
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226

class sum_square(Function):
    """Legacy autograd Function backed by the ``encoding_lib`` sum/square kernels.

    ``forward`` does the following:

    * unpacks a 4-D ``input`` as (B, C, H, W);
    * views it as (B, C, H*W);
    * has the kernel fill two zero-initialised length-C buffers, returned
      as ``(xsum, xsquare)``. Going by the names, these are presumably the
      per-channel sum and sum of squares.

    ``backward`` lets the kernel map ``gradSum``/``gradSquare`` back to an
    input gradient of shape (B, C, H, W).

    Only CUDA float/double tensors are dispatched. Anything else raises
    ``RuntimeError('unimplemented')``.
    """

    def forward(ctx, input):
        ctx.save_for_backward(input)
        batch, channels, height, width = input.size()
        with torch.cuda.device_of(input):
            channel_sum = input.new().resize_(channels).zero_()
            channel_sqsum = input.new().resize_(channels).zero_()

        # Pick the kernel matching the tensor's precision.
        if isinstance(input, torch.cuda.FloatTensor):
            kernel = encoding_lib.Encoding_Float_sum_square_Forward
        elif isinstance(input, torch.cuda.DoubleTensor):
            kernel = encoding_lib.Encoding_Double_sum_square_Forward
        else:
            raise RuntimeError('unimplemented')

        with torch.cuda.device_of(input):
            kernel(input.view(batch, channels, -1), channel_sum, channel_sqsum)
        return channel_sum, channel_sqsum

    def backward(ctx, gradSum, gradSquare):
        input, = ctx.saved_tensors
        batch, channels, height, width = input.size()
        with torch.cuda.device_of(input):
            grad_input = input.new().resize_(
                batch, channels, height * width).zero_()

        if isinstance(input, torch.cuda.FloatTensor):
            kernel = encoding_lib.Encoding_Float_sum_square_Backward
        elif isinstance(input, torch.cuda.DoubleTensor):
            kernel = encoding_lib.Encoding_Double_sum_square_Backward
        else:
            raise RuntimeError('unimplemented')

        with torch.cuda.device_of(input):
            kernel(grad_input, input.view(batch, channels, -1),
                   gradSum, gradSquare)
        return grad_input.view(batch, channels, height, width)

class batchnormtrain(Function):
    """Legacy autograd Function for the batch-norm kernel with flag ``True``.

    ``forward`` runs the ``encoding_lib`` batchnorm kernel on a 3-D
    ``input``. The kernel receives ``mean``, ``invstd = 1 / std``,
    ``gamma`` and ``beta``.

    ``backward`` calls the backward kernel with its final flag set to
    ``True``. That flag presumably enables the mean/std gradient terms used
    in training -- TODO confirm.

    Only CUDA float/double tensors are supported. Any other type raises
    ``RuntimeError('unimplemented')``.
    """

    def forward(ctx, input, gamma, beta, mean, std):
        """Return ``output`` shaped like the 3-D ``input``."""
        ctx.save_for_backward(input, gamma, beta, mean, std)
        assert input.dim() == 3, 'batchnormtrain expects a 3-D input'
        with torch.cuda.device_of(input):
            invstd = 1.0 / std
            output = input.new().resize_as_(input)
        if isinstance(input, torch.cuda.FloatTensor):
            kernel = encoding_lib.Encoding_Float_batchnorm_Forward
        elif isinstance(input, torch.cuda.DoubleTensor):
            kernel = encoding_lib.Encoding_Double_batchnorm_Forward
        else:
            raise RuntimeError('unimplemented')
        with torch.cuda.device_of(input):
            kernel(output, input, mean, invstd, gamma, beta)
        return output

    def backward(ctx, gradOutput):
        """Return gradients for ``(input, gamma, beta, mean, std)``."""
        input, gamma, beta, mean, std = ctx.saved_tensors
        with torch.cuda.device_of(input):
            # Keep invstd inside device_of(input), as forward does, so it is
            # created in the same device context as the gradient buffers.
            invstd = 1.0 / std
            gradInput = gradOutput.new().resize_as_(input).zero_()
            gradGamma = gradOutput.new().resize_as_(gamma).zero_()
            gradBeta = gradOutput.new().resize_as_(beta).zero_()
            gradMean = gradOutput.new().resize_as_(mean).zero_()
            gradStd = gradOutput.new().resize_as_(std).zero_()
        if isinstance(input, torch.cuda.FloatTensor):
            kernel = encoding_lib.Encoding_Float_batchnorm_Backward
        elif isinstance(input, torch.cuda.DoubleTensor):
            kernel = encoding_lib.Encoding_Double_batchnorm_Backward
        else:
            raise RuntimeError('unimplemented')
        with torch.cuda.device_of(input):
            kernel(gradOutput, input, gradInput, gradGamma, gradBeta,
                   mean, invstd, gamma, beta, gradMean, gradStd,
                   True)
        return gradInput, gradGamma, gradBeta, gradMean, gradStd

class batchnormeval(Function):
    """Legacy autograd Function for the batch-norm kernel with flag ``False``.

    ``forward`` runs the ``encoding_lib`` batchnorm kernel on a 3-D
    ``input``. The kernel receives ``mean``, ``invstd = 1 / std``,
    ``gamma`` and ``beta``.

    ``backward`` calls the backward kernel with its final flag set to
    ``False``. That flag presumably disables the mean/std gradient terms
    (evaluation mode) -- TODO confirm.

    Only CUDA float/double tensors are supported. Any other type raises
    ``RuntimeError('unimplemented')``.
    """

    def forward(ctx, input, gamma, beta, mean, std):
        """Return ``output`` shaped like the 3-D ``input``."""
        ctx.save_for_backward(input, gamma, beta, mean, std)
        assert input.dim() == 3, 'batchnormeval expects a 3-D input'
        with torch.cuda.device_of(input):
            invstd = 1.0 / std
            output = input.new().resize_as_(input)
        if isinstance(input, torch.cuda.FloatTensor):
            kernel = encoding_lib.Encoding_Float_batchnorm_Forward
        elif isinstance(input, torch.cuda.DoubleTensor):
            kernel = encoding_lib.Encoding_Double_batchnorm_Forward
        else:
            raise RuntimeError('unimplemented')
        with torch.cuda.device_of(input):
            kernel(output, input, mean, invstd, gamma, beta)
        return output

    def backward(ctx, gradOutput):
        """Return gradients for ``(input, gamma, beta, mean, std)``."""
        input, gamma, beta, mean, std = ctx.saved_tensors
        with torch.cuda.device_of(input):
            # Keep invstd inside device_of(input), as forward does, so it is
            # created in the same device context as the gradient buffers.
            invstd = 1.0 / std
            gradInput = gradOutput.new().resize_as_(input).zero_()
            gradGamma = gradOutput.new().resize_as_(gamma).zero_()
            gradBeta = gradOutput.new().resize_as_(beta).zero_()
            gradMean = gradOutput.new().resize_as_(mean).zero_()
            gradStd = gradOutput.new().resize_as_(std).zero_()
        if isinstance(input, torch.cuda.FloatTensor):
            kernel = encoding_lib.Encoding_Float_batchnorm_Backward
        elif isinstance(input, torch.cuda.DoubleTensor):
            kernel = encoding_lib.Encoding_Double_batchnorm_Backward
        else:
            raise RuntimeError('unimplemented')
        with torch.cuda.device_of(input):
            kernel(gradOutput, input, gradInput, gradGamma, gradBeta,
                   mean, invstd, gamma, beta, gradMean, gradStd,
                   False)
        return gradInput, gradGamma, gradBeta, gradMean, gradStd