##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree.
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Autograd functions for synchronized cross-GPU batch normalization.

The compiled CPU extension is always imported; the GPU extension is imported
only when PyTorch reports at least one visible CUDA device.
"""
import torch
import torch.cuda.comm as comm
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from encoding import cpu

if torch.cuda.device_count():
    from encoding import gpu

__all__ = [
    'moments',
    'syncbatchnorm',
    'inp_syncbatchnorm',
]


class moments_(Function):
    """Autograd function returning the two moments ``(ex, ex2)`` of ``x``.

    ``forward`` delegates to the GPU extension's ``expectation_forward`` and
    ``backward`` maps the incoming gradients of both outputs back to ``x``
    through ``expectation_backward``. Elsewhere in this module the pair is
    treated as E[x] and E[x^2] (the variance is derived as ``ex2 - ex**2``).

    Only CUDA tensors are supported; CPU inputs raise ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx, x):
        if not x.is_cuda:
            # NOTE: ``raise NotImplemented`` would raise a TypeError, because
            # NotImplemented is a constant, not an exception class.
            raise NotImplementedError("moments: only CUDA tensors are supported")
        ex, ex2 = gpu.expectation_forward(x)
        ctx.save_for_backward(x)
        return ex, ex2

    @staticmethod
    def backward(ctx, dex, dex2):
        x, = ctx.saved_tensors
        if not dex.is_cuda:
            raise NotImplementedError("moments: only CUDA tensors are supported")
        return gpu.expectation_backward(x, dex, dex2)


class syncbatchnorm_(Function):
    """Synchronized (cross-device) batch normalization, out-of-place.

    Forward arguments:
        x: input tensor.
        gamma, beta: affine scale and shift.
        running_mean, running_var: running statistics; updated in place (and
            marked dirty) when ``training`` is true.
        extra: dict with ``"is_master"`` and ``"master_queue"``, plus
            ``"worker_queues"``/``"worker_ids"`` on the master or
            ``"worker_queue"`` on a worker (see ``_sync_means``).
        sync: when true, batch moments (and their gradients) are averaged
            across all participating devices.
        training: use batch statistics and update the running ones; otherwise
            normalize with the running statistics.
        momentum: running-statistics update factor.
        eps: numerical-stability constant passed to the extension kernels.
        activation: must be ``"none"`` (no fused activation here).
        slope: stored on ``ctx`` only.

    Training mode requires CUDA tensors; evaluation-mode forward also works
    on CPU. Backward requires CUDA tensors.
    """

    @classmethod
    def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
                extra, sync=True, training=True, momentum=0.1, eps=1e-05,
                activation="none", slope=0.01):
        # save context
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        # An explicit check (instead of ``assert``) so it survives ``python -O``.
        if activation != 'none':
            raise ValueError(
                "syncbatchnorm: unsupported activation %r (only 'none')" % (activation,))

        # contiguous inputs for the extension kernels
        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()

        if ctx.training:
            if not x.is_cuda:
                raise NotImplementedError(
                    "syncbatchnorm: training mode requires CUDA tensors")
            _ex, _exs = gpu.expectation_forward(x)

            if ctx.sync:
                _ex, _exs = cls._sync_means(ctx, _ex, _exs)

            # Update running stats; variance derived as E[x^2] - E[x]^2.
            _var = _exs - _ex ** 2
            running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * _ex)
            running_var.mul_(1 - ctx.momentum).add_(ctx.momentum * _var)

            # Mark in-place modified tensors
            ctx.mark_dirty(running_mean, running_var)
        else:
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex ** 2

        # BN forward
        if x.is_cuda:
            y = gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
        else:
            y = cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return y

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        x, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # BN backward
        if not dz.is_cuda:
            raise NotImplementedError(
                "syncbatchnorm: backward requires CUDA tensors")
        dx, _dex, _dexs, dgamma, dbeta = \
            gpu.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)

        if ctx.training:
            # In training mode the batch moments depended on x, so their
            # gradients (averaged across devices when syncing) flow back too.
            if ctx.sync:
                _dex, _dexs = syncbatchnorm_._sync_means(ctx, _dex, _dexs)

            if not x.is_cuda:
                raise NotImplementedError(
                    "syncbatchnorm: backward requires CUDA tensors")
            dx = dx + gpu.expectation_backward(x, _dex, _dexs)

        # No gradients for running_mean, running_var, extra, sync, training,
        # momentum, eps, activation, slope.
        return (dx, dgamma, dbeta) + (None,) * 9

    @staticmethod
    def _sync_means(ctx, a, b):
        """Average the pair ``(a, b)`` over the master and all worker devices.

        The master reads exactly ``ctx.master_queue.maxsize`` pairs from
        ``ctx.master_queue``, averages them together with its own pair, and
        broadcasts the result to ``ctx.worker_ids`` through
        ``ctx.worker_queues``. A worker puts its pair on ``ctx.master_queue``
        and then waits for the averaged pair on ``ctx.worker_queue``.
        Every participant returns the averaged pair.
        """
        if ctx.is_master:
            a, b = [a.unsqueeze(0)], [b.unsqueeze(0)]
            for _ in range(ctx.master_queue.maxsize):
                a_w, b_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                a.append(a_w.unsqueeze(0))
                b.append(b_w.unsqueeze(0))

            a = comm.gather(a).mean(0)
            b = comm.gather(b).mean(0)

            tensors = comm.broadcast_coalesced((a, b), [a.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            ctx.master_queue.put((a, b))
            a, b = ctx.worker_queue.get()
            ctx.worker_queue.task_done()
        return a, b

    @staticmethod
    def _parse_extra(ctx, extra):
        """Copy the cross-device communication handles from ``extra`` onto ``ctx``."""
        ctx.is_master = extra["is_master"]
        ctx.master_queue = extra["master_queue"]
        if ctx.is_master:
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.worker_queue = extra["worker_queue"]

def _act_forward(ctx, x):
    if ctx.activation.lower() == "leaky_relu":
        if x.is_cuda:
Hang Zhang's avatar
Hang Zhang committed
164
            gpu.leaky_relu_forward(x, ctx.slope)
Hang Zhang's avatar
Hang Zhang committed
165
166
167
168
169
170
171
172
        else:
            raise NotImplemented
    else:
        assert activation == 'none'

def _act_backward(ctx, x, dx):
    if ctx.activation.lower() == "leaky_relu":
        if x.is_cuda:
Hang Zhang's avatar
Hang Zhang committed
173
            gpu.leaky_relu_backward(x, dx, ctx.slope)
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
174
        else:
Hang Zhang's avatar
Hang Zhang committed
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
            raise NotImplemented
    else:
        assert activation == 'none'

class inp_syncbatchnorm_(Function):
    """Synchronized (cross-device) batch normalization computed in place on ``x``.

    Forward arguments match ``syncbatchnorm``:
        x: input tensor; normalized, then activated, in place and returned.
        gamma, beta: affine scale and shift.
        running_mean, running_var: running statistics; updated in place (and
            marked dirty) when ``training`` is true.
        extra: dict with ``"is_master"`` and ``"master_queue"``, plus
            ``"worker_queues"``/``"worker_ids"`` on the master or
            ``"worker_queue"`` on a worker (see ``_sync_means``).
        sync: when true, batch moments (and their gradients) are averaged
            across all participating devices.
        training: use batch statistics and update the running ones; otherwise
            normalize with the running statistics.
        momentum: running-statistics update factor.
        eps: numerical-stability constant passed to the extension kernels.
        activation: ``"none"`` or ``"leaky_relu"``, applied after normalization.
        slope: slope used by the leaky ReLU activation.

    The activated output is saved for backward instead of the input. All
    paths require CUDA tensors.
    """

    @classmethod
    def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
                extra, sync=True, training=True, momentum=0.1, eps=1e-05,
                activation="none", slope=0.01):
        # save context
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope

        # contiguous inputs for the extension kernels
        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()

        if ctx.training:
            if not x.is_cuda:
                raise NotImplementedError(
                    "inp_syncbatchnorm: training mode requires CUDA tensors")
            _ex, _exs = gpu.expectation_forward(x)

            if ctx.sync:
                _ex, _exs = cls._sync_means(ctx, _ex, _exs)

            # Update running stats; variance derived as E[x^2] - E[x]^2.
            _var = _exs - _ex ** 2
            running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * _ex)
            running_var.mul_(1 - ctx.momentum).add_(ctx.momentum * _var)

            # Mark in-place modified tensors (x is overwritten below).
            ctx.mark_dirty(x, running_mean, running_var)
        else:
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex ** 2
            ctx.mark_dirty(x)

        # In-place BN forward followed by the activation
        if not x.is_cuda:
            raise NotImplementedError(
                "inp_syncbatchnorm: forward requires CUDA tensors")
        gpu.batchnorm_inp_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        _act_forward(ctx, x)

        # x now holds the activated output; backward reconstructs from it.
        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        z, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        # BN backward
        if not dz.is_cuda:
            raise NotImplementedError(
                "inp_syncbatchnorm: backward requires CUDA tensors")
        dx, _dex, _dexs, dgamma, dbeta = \
            gpu.batchnorm_inp_backward(dz, z, _ex, _exs, gamma, beta, ctx.eps)

        if ctx.training:
            # Gradients of the batch moments (averaged across devices when
            # syncing) are folded into dx in place by the extension.
            if ctx.sync:
                _dex, _dexs = inp_syncbatchnorm_._sync_means(ctx, _dex, _dexs)

            if not z.is_cuda:
                raise NotImplementedError(
                    "inp_syncbatchnorm: backward requires CUDA tensors")
            gpu.expectation_inp_backward(dx, z, _dex, _dexs, _ex, _exs, gamma, beta, ctx.eps)

        # No gradients for running_mean, running_var, extra, sync, training,
        # momentum, eps, activation, slope.
        return (dx, dgamma, dbeta) + (None,) * 9

    @staticmethod
    def _sync_means(ctx, a, b):
        """Average the pair ``(a, b)`` over the master and all worker devices.

        The master reads exactly ``ctx.master_queue.maxsize`` pairs from
        ``ctx.master_queue``, averages them together with its own pair, and
        broadcasts the result to ``ctx.worker_ids`` through
        ``ctx.worker_queues``. A worker puts its pair on ``ctx.master_queue``
        and then waits for the averaged pair on ``ctx.worker_queue``.
        Every participant returns the averaged pair.
        """
        if ctx.is_master:
            a, b = [a.unsqueeze(0)], [b.unsqueeze(0)]
            for _ in range(ctx.master_queue.maxsize):
                a_w, b_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                a.append(a_w.unsqueeze(0))
                b.append(b_w.unsqueeze(0))

            a = comm.gather(a).mean(0)
            b = comm.gather(b).mean(0)

            tensors = comm.broadcast_coalesced((a, b), [a.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            ctx.master_queue.put((a, b))
            a, b = ctx.worker_queue.get()
            ctx.worker_queue.task_done()
        return a, b

    @staticmethod
    def _parse_extra(ctx, extra):
        """Copy the cross-device communication handles from ``extra`` onto ``ctx``."""
        ctx.is_master = extra["is_master"]
        ctx.master_queue = extra["master_queue"]
        if ctx.is_master:
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.worker_queue = extra["worker_queue"]


# Public functional entry points: each call runs the matching autograd Function.
moments, syncbatchnorm, inp_syncbatchnorm = (
    moments_.apply,
    syncbatchnorm_.apply,
    inp_syncbatchnorm_.apply,
)