##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

Zhang's avatar
Zhang committed
10
"""Synchronized Cross-GPU Batch Normalization functions"""
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
11
import torch
Hang Zhang's avatar
Hang Zhang committed
12
import torch.cuda.comm as comm
Hang Zhang's avatar
Hang Zhang committed
13
from torch.autograd import Function
Hang Zhang's avatar
Hang Zhang committed
14
from torch.autograd.function import once_differentiable
Zhang's avatar
v0.4.2  
Zhang committed
15
from .. import lib
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
16

Hang Zhang's avatar
Hang Zhang committed
17
# Public names exported by this module. 'moments' is the Function class
# itself; 'syncbatchnorm' and 'inp_syncbatchnorm' are the ``.apply``
# callables bound at the bottom of the file.
__all__ = ['moments', 'syncbatchnorm', 'inp_syncbatchnorm']
Zhang's avatar
Zhang committed
18

Hang Zhang's avatar
Hang Zhang committed
19
20
21
22
23
24
25
26
class moments(Function):
    """Autograd wrapper around the GPU expectation kernels.

    ``forward`` returns the pair produced by ``lib.gpu.expectation_forward``
    (named ``ex`` / ``ex2`` here; presumably E[x] and E[x^2] -- confirm
    against the extension). ``backward`` maps the gradients of that pair
    back to a gradient for ``x`` via ``lib.gpu.expectation_backward``.

    Only CUDA tensors are supported; any other input raises
    ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx, x):
        """Compute the two expectation outputs for ``x``.

        Raises:
            NotImplementedError: if ``x`` is not a CUDA tensor.
        """
        if not x.is_cuda:
            # ``NotImplemented`` is a comparison sentinel, not an exception;
            # raising it would surface as an unrelated TypeError.
            raise NotImplementedError(
                "moments is only implemented for CUDA tensors")
        # backward() needs the input again, so it has to be saved explicitly.
        ctx.save_for_backward(x)
        ex, ex2 = lib.gpu.expectation_forward(x)
        return ex, ex2

    @staticmethod
    def backward(ctx, dex, dex2):
        """Return the gradient w.r.t. ``x`` given gradients of both outputs.

        Raises:
            NotImplementedError: if the saved input is not a CUDA tensor.
        """
        x, = ctx.saved_tensors
        if not x.is_cuda:
            raise NotImplementedError(
                "moments is only implemented for CUDA tensors")
        dx = lib.gpu.expectation_backward(x, dex, dex2)
        return dx

class syncbatchnorm_(Function):
    """Batch normalization whose batch statistics can be averaged across devices.

    In training mode the two expectation tensors returned by
    ``lib.gpu.expectation_forward`` are (optionally) averaged over all
    participating devices through the queues supplied in ``extra``, the
    running statistics are updated in place, and the normalized output is
    produced by ``lib.gpu.batchnorm_forward``. In evaluation mode the running
    statistics are used instead.

    The activation arguments are accepted for signature parity with
    ``inp_syncbatchnorm_``, but only ``activation == 'none'`` is allowed.
    """

    @classmethod
    def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
                extra, sync=True, training=True, momentum=0.1, eps=1e-05,
                activation="none", slope=0.01):
        """Normalize ``x`` and return the result as a new tensor.

        Args:
            x: input tensor.
            gamma, beta: tensors forwarded to the batchnorm kernels.
            running_mean, running_var: updated in place when ``training``.
            extra: dict with ``is_master``, ``master_queue`` and either
                ``worker_queues`` + ``worker_ids`` (master) or
                ``worker_queue`` (worker).
            sync: average the batch statistics across devices.
            training: use batch statistics and update the running ones.
            momentum: weight of the new statistics in the running update.
            eps: forwarded to the batchnorm kernels.
            activation: must be ``'none'``.
            slope: stored on ``ctx``; unused by this class.

        Raises:
            NotImplementedError: training mode with a non-CUDA input.
        """
        # save context
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        assert activation == 'none'

        # contiguous inputs
        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()

        if ctx.training:
            if not x.is_cuda:
                # ``NotImplemented`` is not an exception; raising it would
                # produce a confusing TypeError instead.
                raise NotImplementedError(
                    "training-mode syncbatchnorm requires CUDA tensors")
            _ex, _exs = lib.gpu.expectation_forward(x)

            if ctx.sync:
                _ex, _exs = cls._sync_mean(ctx, _ex, _exs)

            # Update running stats; the variance is derived as E[x^2] - E[x]^2
            _var = _exs - _ex ** 2
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * _ex)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * _var)

            # Mark in-place modified tensors
            ctx.mark_dirty(running_mean, running_var)
        else:
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            # Rebuild the second expectation expected by the kernels.
            _exs = _var + _ex ** 2

        # BN forward
        if x.is_cuda:
            y = lib.gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
        else:
            y = lib.cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        # Output
        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return y

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        """Return gradients for ``x``, ``gamma`` and ``beta``.

        The remaining nine forward arguments receive ``None``.

        Raises:
            NotImplementedError: if ``dz`` (or the saved input) is not on CUDA.
        """
        x, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # BN backward
        if not dz.is_cuda:
            raise NotImplementedError(
                "syncbatchnorm backward requires CUDA tensors")
        dx, _dex, _dexs, dgamma, dbeta = \
            lib.gpu.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)

        if ctx.training:
            if ctx.sync:
                # Gradients of the statistics are averaged the same way the
                # statistics themselves were in forward().
                _dex, _dexs = syncbatchnorm_._sync_mean(ctx, _dex, _dexs)

            if not x.is_cuda:
                raise NotImplementedError(
                    "syncbatchnorm backward requires CUDA tensors")
            # Add the contribution flowing through the batch statistics.
            dx_ = lib.gpu.expectation_backward(x, _dex, _dexs)
            dx = dx + dx_

        return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None

    @staticmethod
    def _sync_mean(ctx, first, second):
        """Average a pair of tensors across all participating devices.

        The master collects one pair per ``master_queue.maxsize`` worker,
        gathers and averages them together with its own pair, then broadcasts
        the result to its own device plus ``worker_ids`` and hands each worker
        its copy through ``worker_queues``. A worker submits its pair to
        ``master_queue`` and waits on ``worker_queue`` for the averaged pair.
        """
        if ctx.is_master:
            firsts, seconds = [first.unsqueeze(0)], [second.unsqueeze(0)]
            for _ in range(ctx.master_queue.maxsize):
                first_w, second_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                firsts.append(first_w.unsqueeze(0))
                seconds.append(second_w.unsqueeze(0))

            first = comm.gather(firsts).mean(0)
            second = comm.gather(seconds).mean(0)

            tensors = comm.broadcast_coalesced(
                (first, second), [first.get_device()] + ctx.worker_ids)
            # tensors[0] is the master's own copy; the rest go to the workers.
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            ctx.master_queue.put((first, second))
            first, second = ctx.worker_queue.get()
            ctx.worker_queue.task_done()
        return first, second

    @staticmethod
    def _parse_extra(ctx, extra):
        """Copy the synchronization handles from ``extra`` onto ``ctx``."""
        ctx.is_master = extra["is_master"]
        if ctx.is_master:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queue = extra["worker_queue"]

def _act_forward(ctx, x):
    """Apply the activation selected by ``ctx.activation`` to ``x``.

    For ``"leaky_relu"`` (case-insensitive) the call is delegated to
    ``lib.gpu.leaky_relu_forward(x, ctx.slope)``; its return value is not
    used, so the kernel is presumably in-place -- confirm against the
    extension. Any other value must be exactly ``'none'``, in which case
    nothing happens.

    Raises:
        NotImplementedError: leaky ReLU requested for a non-CUDA tensor.
        AssertionError: ``ctx.activation`` is neither leaky ReLU nor 'none'.
    """
    if ctx.activation.lower() == "leaky_relu":
        if x.is_cuda:
            lib.gpu.leaky_relu_forward(x, ctx.slope)
        else:
            # ``NotImplemented`` is not an exception class.
            raise NotImplementedError(
                "leaky_relu forward requires a CUDA tensor")
    else:
        # The activation lives on ctx; a bare ``activation`` name is not
        # defined in this scope.
        assert ctx.activation == 'none'

def _act_backward(ctx, x, dx):
    """Run the backward step of the activation selected by ``ctx.activation``.

    For ``"leaky_relu"`` (case-insensitive) the call is delegated to
    ``lib.gpu.leaky_relu_backward(x, dx, ctx.slope)``; its return value is
    not used, so the kernel is presumably in-place -- confirm against the
    extension. Any other value must be exactly ``'none'``, in which case
    nothing happens.

    Raises:
        NotImplementedError: leaky ReLU requested for a non-CUDA tensor.
        AssertionError: ``ctx.activation`` is neither leaky ReLU nor 'none'.
    """
    if ctx.activation.lower() == "leaky_relu":
        if x.is_cuda:
            lib.gpu.leaky_relu_backward(x, dx, ctx.slope)
        else:
            # ``NotImplemented`` is not an exception class.
            raise NotImplementedError(
                "leaky_relu backward requires a CUDA tensor")
    else:
        # The activation lives on ctx; a bare ``activation`` name is not
        # defined in this scope.
        assert ctx.activation == 'none'

class inp_syncbatchnorm_(Function):
    """Variant of synchronized batch norm that returns its input tensor.

    The flow mirrors ``syncbatchnorm_``: batch expectations are computed,
    optionally averaged across devices through the queues in ``extra``, and
    the running statistics are updated in place. The normalization itself is
    done by ``lib.gpu.batchnorm_inp_forward`` followed by ``_act_forward``;
    ``x`` is marked dirty and returned, so both calls are expected to
    overwrite it in place. Only CUDA tensors are supported.
    """

    @classmethod
    def forward(cls, ctx, x, gamma, beta, running_mean, running_var,
                extra, sync=True, training=True, momentum=0.1, eps=1e-05,
                activation="none", slope=0.01):
        """Normalize (and optionally activate) ``x`` and return it.

        Args:
            x: input tensor; marked dirty and returned.
            gamma, beta: tensors forwarded to the batchnorm kernels.
            running_mean, running_var: updated in place when ``training``.
            extra: dict with ``is_master``, ``master_queue`` and either
                ``worker_queues`` + ``worker_ids`` (master) or
                ``worker_queue`` (worker).
            sync: average the batch statistics across devices.
            training: use batch statistics and update the running ones.
            momentum: weight of the new statistics in the running update.
            eps: forwarded to the batchnorm kernels.
            activation: ``'none'`` or ``'leaky_relu'`` (see ``_act_forward``).
            slope: slope passed to the leaky ReLU kernels.

        Raises:
            NotImplementedError: if ``x`` is not a CUDA tensor.
        """
        # save context
        cls._parse_extra(ctx, extra)
        ctx.sync = sync
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope

        # contiguous inputs
        # NOTE(review): if x was not contiguous this rebinds x to a copy, so
        # the tensor marked dirty below is no longer the caller's input --
        # confirm callers always pass contiguous tensors.
        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()

        if ctx.training:
            if not x.is_cuda:
                # ``NotImplemented`` is not an exception; raising it would
                # produce a confusing TypeError instead.
                raise NotImplementedError(
                    "inp_syncbatchnorm requires CUDA tensors")
            _ex, _exs = lib.gpu.expectation_forward(x)

            if ctx.sync:
                _ex, _exs = cls._sync_mean(ctx, _ex, _exs)

            # Update running stats; the variance is derived as E[x^2] - E[x]^2
            _var = _exs - _ex ** 2
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * _ex)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * _var)

            # Mark in-place modified tensors
            ctx.mark_dirty(x, running_mean, running_var)
        else:
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            # Rebuild the second expectation expected by the kernels.
            _exs = _var + _ex ** 2
            ctx.mark_dirty(x)

        # BN forward + activation
        if not x.is_cuda:
            raise NotImplementedError(
                "inp_syncbatchnorm requires CUDA tensors")
        lib.gpu.batchnorm_inp_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        _act_forward(ctx, x)

        # Output
        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        """Return gradients for ``x``, ``gamma`` and ``beta``.

        The saved tensor ``z`` is the forward *output* (the input itself was
        not kept). The remaining nine forward arguments receive ``None``.

        Raises:
            NotImplementedError: if ``dz`` or ``z`` is not a CUDA tensor.
        """
        z, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        # BN backward
        if not dz.is_cuda:
            raise NotImplementedError(
                "inp_syncbatchnorm backward requires CUDA tensors")
        dx, _dex, _dexs, dgamma, dbeta = \
            lib.gpu.batchnorm_inp_backward(dz, z, _ex, _exs, gamma, beta, ctx.eps)

        if ctx.training:
            if ctx.sync:
                # Gradients of the statistics are averaged the same way the
                # statistics themselves were in forward().
                _dex, _dexs = inp_syncbatchnorm_._sync_mean(ctx, _dex, _dexs)

            if not z.is_cuda:
                raise NotImplementedError(
                    "inp_syncbatchnorm backward requires CUDA tensors")
            # Return value unused: dx is passed in and returned below, so the
            # kernel is expected to accumulate into it in place.
            lib.gpu.expectation_inp_backward(dx, z, _dex, _dexs, _ex, _exs, gamma, beta, ctx.eps)

        return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None

    @staticmethod
    def _sync_mean(ctx, first, second):
        """Average a pair of tensors across all participating devices.

        The master collects one pair per ``master_queue.maxsize`` worker,
        gathers and averages them together with its own pair, then broadcasts
        the result to its own device plus ``worker_ids`` and hands each worker
        its copy through ``worker_queues``. A worker submits its pair to
        ``master_queue`` and waits on ``worker_queue`` for the averaged pair.
        """
        if ctx.is_master:
            firsts, seconds = [first.unsqueeze(0)], [second.unsqueeze(0)]
            for _ in range(ctx.master_queue.maxsize):
                first_w, second_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                firsts.append(first_w.unsqueeze(0))
                seconds.append(second_w.unsqueeze(0))

            first = comm.gather(firsts).mean(0)
            second = comm.gather(seconds).mean(0)

            tensors = comm.broadcast_coalesced(
                (first, second), [first.get_device()] + ctx.worker_ids)
            # tensors[0] is the master's own copy; the rest go to the workers.
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            ctx.master_queue.put((first, second))
            first, second = ctx.worker_queue.get()
            ctx.worker_queue.task_done()
        return first, second

    @staticmethod
    def _parse_extra(ctx, extra):
        """Copy the synchronization handles from ``extra`` onto ``ctx``."""
        ctx.is_master = extra["is_master"]
        if ctx.is_master:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queue = extra["worker_queue"]
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
297

Hang Zhang's avatar
Hang Zhang committed
298
299
# Functional entry points: bind each Function's ``apply`` to the public name
# listed in ``__all__`` so callers invoke them like ordinary functions.
syncbatchnorm = syncbatchnorm_.apply
inp_syncbatchnorm = inp_syncbatchnorm_.apply