dist_syncbn.py
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import torch
from torch.autograd.function import Function

# C++/CUDA extension kernels shipped with PyTorch-Encoding
from encoding import cpu
if torch.cuda.device_count() > 0:
    from encoding import gpu

__all__ = ['dist_syncbatchnorm']

class dist_syncbatchnorm_(Function):
    """Synchronized BatchNorm autograd function: batch statistics are all-reduced
    across the given process group before normalization."""
    @staticmethod
    def forward(ctx, x, gamma, beta, running_mean, running_var, eps, momentum, training, process_group):
        x = x.contiguous()
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.process_group = process_group

        if not ctx.training:
            # Inference: use the running statistics (E[x^2] = Var[x] + E[x]^2)
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex ** 2
            if x.is_cuda:
                y = gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
            else:
                y = cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
            ctx.save_for_backward(x, _ex, _exs, gamma, beta)
            return y

        # Number of elements per channel on this process
        size = x.numel() // x.size(1)
        if size == 1:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        if x.is_cuda:
            # Per-process E[x] and E[x^2] for each channel
            _ex, _exs = gpu.expectation_forward(x)
        else:
            raise NotImplementedError('CPU training is not supported')

        # All-reduce a count of one per process along with the raw moments;
        # after the reduction, count equals the number of processes in the group.
        count = torch.Tensor([1]).to(x.device)
        count_all_reduce = torch.distributed.all_reduce(count, group=process_group, async_op=True)
        _ex_all_reduce = torch.distributed.all_reduce(_ex, group=process_group, async_op=True)
        _exs_all_reduce = torch.distributed.all_reduce(_exs, group=process_group, async_op=True)

        count_all_reduce.wait()
        _ex_all_reduce.wait()
        _exs_all_reduce.wait()

        # Average the moments across processes (assumes equal per-process batch sizes)
        _ex = _ex / count
        _exs = _exs / count

        # Update running stats
        _var = _exs - _ex ** 2
        running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * _ex)
        running_var.mul_(1 - ctx.momentum).add_(ctx.momentum * _var)

        # Mark in-place modified tensors
        ctx.mark_dirty(running_mean, running_var)

        # BN forward + activation
        if x.is_cuda:
            y = gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
        else:
            y = cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return y

    @staticmethod
    def backward(ctx, dz):
        x, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # BN backward
        if dz.is_cuda:
            dx, _dex, _dexs, dgamma, dbeta = \
                gpu.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)
        else:
            raise NotImplementedError('CPU backward is not implemented')

        if ctx.training:
            # All-reduce the gradients w.r.t. E[x] and E[x^2] across the process
            # group, mirroring the forward synchronization of the statistics.
            process_group = ctx.process_group
            count = torch.Tensor([1]).to(x.device)
            count_all_reduce = torch.distributed.all_reduce(count, group=process_group, async_op=True)
            _dex_all_reduce = torch.distributed.all_reduce(_dex, group=process_group, async_op=True)
            _dexs_all_reduce = torch.distributed.all_reduce(_dexs, group=process_group, async_op=True)

            count_all_reduce.wait()
            _dex_all_reduce.wait()
            _dexs_all_reduce.wait()

            _dex = _dex / count
            _dexs = _dexs / count

            if x.is_cuda:
                # Gradient contribution of x through the shared statistics
                dx_ = gpu.expectation_backward(x, _dex, _dexs)
            else:
                raise NotImplementedError('CPU backward is not implemented')
            dx = dx + dx_

        return dx, dgamma, dbeta, None, None, None, None, None, None

dist_syncbatchnorm = dist_syncbatchnorm_.apply
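
# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): one way an nn.Module
# wrapper might call dist_syncbatchnorm. The class name `DistSyncBNSketch` is
# hypothetical; PyTorch-Encoding ships its own wrapper. It assumes
# torch.distributed is already initialized and that the default process group
# should be used (process_group=None).
# ---------------------------------------------------------------------------
import torch.nn as nn

class DistSyncBNSketch(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps, self.momentum = eps, momentum
        self.weight = nn.Parameter(torch.ones(num_features))   # gamma
        self.bias = nn.Parameter(torch.zeros(num_features))    # beta
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        # Signature: (x, gamma, beta, running_mean, running_var,
        #             eps, momentum, training, process_group)
        return dist_syncbatchnorm(
            x, self.weight, self.bias,
            self.running_mean, self.running_var,
            self.eps, self.momentum, self.training,
            None)  # None selects the default process group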