"vscode:/vscode.git/clone" did not exist on "4a4772ae03c8b29834efbfa1175ba6abeafa77c9"
dist_syncbn.py 3.71 KB
Newer Older
Hang Zhang's avatar
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import torch
from torch.autograd.function import Function
from .. import lib

__all__ = ['dist_syncbatchnorm']
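
# dist_syncbatchnorm synchronizes batch statistics across workers: each device
# computes per-channel E[x] and E[x^2], the values are summed with all_reduce
# over the process group and averaged by the worker count, and the variance is
# recovered as Var[x] = E[x^2] - E[x]^2 before normalization and the affine
# transform with gamma and beta.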

class dist_syncbatchnorm_(Function):
    @staticmethod
    def forward(ctx, x, gamma, beta, running_mean, running_var, eps, momentum, training, process_group):
        x = x.contiguous()
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.process_group = process_group

        if not ctx.training:
            # Inference: reuse the running statistics; E[x^2] is recovered as Var[x] + E[x]^2.
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _var + _ex ** 2
            if x.is_cuda:
                y = lib.gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
            else:
                y = lib.cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
            ctx.save_for_backward(x, _ex, _exs, gamma, beta)
            return y

        # Number of values per channel on this device; statistics are undefined for a single value.
        size = x.numel() // x.size(1)
        if size == 1:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        # Per-device per-channel expectations E[x] and E[x^2].
        if x.is_cuda:
            _ex, _exs = lib.gpu.expectation_forward(x)
        else:
            raise NotImplementedError('dist_syncbatchnorm training is only implemented for CUDA tensors')

        # Sum the per-device statistics across the process group; `count` records
        # how many workers contributed so the sums can be averaged.
        count = torch.ones(1, device=x.device)
        count_all_reduce = torch.distributed.all_reduce(count, group=process_group, async_op=True)
        _ex_all_reduce = torch.distributed.all_reduce(_ex, group=process_group, async_op=True)
        _exs_all_reduce = torch.distributed.all_reduce(_exs, group=process_group, async_op=True)

        count_all_reduce.wait()
        _ex_all_reduce.wait()
        _exs_all_reduce.wait()

        # Average the summed statistics over the number of workers.
        _ex = _ex / count
        _exs = _exs / count

        # Update running stats
        _var = _exs - _ex ** 2
        running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * _ex)
        running_var.mul_(1 - ctx.momentum).add_(ctx.momentum * _var)

        # Mark in-place modified tensors
        ctx.mark_dirty(running_mean, running_var)

        # BN forward + activation
        if x.is_cuda:
            y = lib.gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
        else:
            y = lib.cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

        ctx.save_for_backward(x, _ex, _exs, gamma, beta)
        return y

    @staticmethod
    def backward(ctx, dz):
        x, _ex, _exs, gamma, beta = ctx.saved_tensors
        dz = dz.contiguous()

        # BN backward
        if dz.is_cuda:
            dx, _dex, _dexs, dgamma, dbeta = \
                lib.gpu.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)
        else:
            raise NotImplementedError('dist_syncbatchnorm backward is only implemented for CUDA tensors')

        if ctx.training:
            # Reduce and average the statistic gradients across the process group, mirroring forward.
            process_group = ctx.process_group
            count = torch.ones(1, device=x.device)
            count_all_reduce = torch.distributed.all_reduce(count, group=process_group, async_op=True)
            _dex_all_reduce = torch.distributed.all_reduce(_dex, group=process_group, async_op=True)
            _dexs_all_reduce = torch.distributed.all_reduce(_dexs, group=process_group, async_op=True)

            count_all_reduce.wait()
            _dex_all_reduce.wait()
            _dexs_all_reduce.wait()

            _dex = _dex / count
            _dexs = _dexs / count

            # Propagate the averaged statistic gradients back through the expectation step.
            if x.is_cuda:
                dx_ = lib.gpu.expectation_backward(x, _dex, _dexs)
            else:
                raise NotImplementedError('dist_syncbatchnorm backward is only implemented for CUDA tensors')
            dx = dx + dx_

        return dx, dgamma, dbeta, None, None, None, None, None, None

dist_syncbatchnorm = dist_syncbatchnorm_.apply
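

# -----------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal nn.Module wrapper
# illustrating how dist_syncbatchnorm might be called. It assumes
# torch.distributed is initialized and the input is a CUDA tensor, since the
# CPU training path above is not implemented. The class name is hypothetical
# and not a library API.
# -----------------------------------------------------------------------------
import torch.nn as nn


class _DistSyncBNExample(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, process_group=None):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.process_group = process_group
        # Per-channel affine parameters and running statistics.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        # Argument order matches dist_syncbatchnorm_.forward above.
        return dist_syncbatchnorm(
            x, self.weight, self.bias, self.running_mean, self.running_var,
            self.eps, self.momentum, self.training, self.process_group)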