norm.py 3.98 KB
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
import torch
2
from mmcv.cnn import NORM_LAYERS
zhangwenwei's avatar
zhangwenwei committed
3
4
from torch import distributed as dist
from torch import nn as nn
zhangwenwei's avatar
zhangwenwei committed
5
from torch.autograd.function import Function
zhangwenwei's avatar
zhangwenwei committed
6

zhangwenwei's avatar
zhangwenwei committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

class AllReduce(Function):
    """Differentiable element-wise sum of a tensor over all workers.

    The forward pass returns the sum of ``input`` over every rank of the
    default process group. The backward pass sums the incoming gradient
    over the same ranks.
    """

    @staticmethod
    def forward(ctx, input):
        """Gather ``input`` from every rank and return the summed tensor.

        Args:
            ctx: Autograd context (unused).
            input (torch.Tensor): Local tensor. The receive buffers are
                created with ``zeros_like(input)``, so every rank is
                expected to pass a tensor of the same shape.

        Returns:
            torch.Tensor: Element-wise sum of the tensors of all ranks.
        """
        world_size = dist.get_world_size()
        gathered = [torch.zeros_like(input) for _ in range(world_size)]
        # Use all_gather rather than all_reduce here, because the in-place
        # all_reduce is considered unreliable for the forward pass.
        dist.all_gather(gathered, input, async_op=False)
        return torch.stack(gathered, dim=0).sum(dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Sum the incoming gradient over all ranks.

        Args:
            ctx: Autograd context (unused).
            grad_output (torch.Tensor): Gradient w.r.t. the forward output.

        Returns:
            torch.Tensor: Gradient w.r.t. the forward input.
        """
        # Autograd may hand the same gradient tensor to other nodes and may
        # pass a non-contiguous (e.g. expanded) view, so reduce a contiguous
        # private copy instead of overwriting ``grad_output`` in place.
        grad_input = grad_output.clone().contiguous()
        dist.all_reduce(grad_input, async_op=False)
        return grad_input


26
@NORM_LAYERS.register_module('naiveSyncBN1d')
class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
    """Synchronized Batch Normalization for 2D/3D Tensors.

    During distributed training the per-channel mean and mean of squares
    are averaged over all workers (each worker contributes with equal
    weight) before the input is normalized. In every other situation the
    layer behaves exactly like ``nn.BatchNorm1d``.

    Note:
        This implementation is modified from
        https://github.com/facebookresearch/detectron2/

        `torch.nn.SyncBatchNorm` has known unknown bugs.
        It produces significantly worse AP (and sometimes goes NaN)
        when the batch size on each worker is quite different
        (e.g., when scale augmentation is used).
        In 3D detection, different workers have points of different shapes,
        which also causes instability.

        Use this implementation before `nn.SyncBatchNorm` is fixed.
        It is slower than `nn.SyncBatchNorm`.
    """

    def forward(self, input):
        """Normalize ``input`` with statistics shared by all workers.

        Args:
            input (torch.Tensor): Tensor of shape (N, C) or (N, C, L).

        Returns:
            torch.Tensor: Normalized tensor with the same shape as
                ``input``.
        """
        # ``dist.get_world_size()`` raises when no process group has been
        # initialized, so check for that first and fall back to the plain
        # BatchNorm whenever there is nothing to synchronize.
        using_dist = dist.is_available() and dist.is_initialized()
        if not using_dist or not self.training \
                or dist.get_world_size() == 1:
            return super().forward(input)

        # Same dimensionality check the fallback path performs.
        self._check_input_dim(input)
        assert input.shape[0] > 0, 'SyncBN does not support empty inputs'

        # ``nn.BatchNorm1d`` also accepts (N, C); add a length-1 axis so
        # the reductions below can treat both layouts alike.
        is_two_dim = input.dim() == 2
        if is_two_dim:
            input = input.unsqueeze(2)

        world_size = dist.get_world_size()
        C = input.shape[1]
        mean = torch.mean(input, dim=[0, 2])
        meansqr = torch.mean(input * input, dim=[0, 2])

        # Send both statistics through a single collective call.
        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / world_size)

        mean, meansqr = torch.split(vec, C)
        var = meansqr - mean * mean  # Var[x] = E[x^2] - E[x]^2
        self.running_mean += self.momentum * (
            mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)

        # Fold normalization and the affine transform into one scale/bias.
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1)
        bias = bias.reshape(1, -1, 1)
        output = input * scale + bias
        if is_two_dim:
            output = output.squeeze(2)
        return output


71
@NORM_LAYERS.register_module('naiveSyncBN2d')
class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
    """Synchronized Batch Normalization for 4D Tensors.

    During distributed training the per-channel mean and mean of squares
    are averaged over all workers (each worker contributes with equal
    weight) before the input is normalized. In every other situation the
    layer behaves exactly like ``nn.BatchNorm2d``.

    Note:
        This implementation is modified from
        https://github.com/facebookresearch/detectron2/

        `torch.nn.SyncBatchNorm` has known unknown bugs.
        It produces significantly worse AP (and sometimes goes NaN)
        when the batch size on each worker is quite different
        (e.g., when scale augmentation is used).
        This phenomenon also occurs when the multi-modality feature fusion
        modules of multi-modality detectors use SyncBN.

        Use this implementation before `nn.SyncBatchNorm` is fixed.
        It is slower than `nn.SyncBatchNorm`.
    """

    def forward(self, input):
        """Normalize ``input`` with statistics shared by all workers.

        Args:
            input (torch.Tensor): Tensor of shape (N, C, H, W).

        Returns:
            torch.Tensor: Normalized tensor with the same shape as
                ``input``.
        """
        # ``dist.get_world_size()`` raises when no process group has been
        # initialized, so check for that first and fall back to the plain
        # BatchNorm whenever there is nothing to synchronize.
        using_dist = dist.is_available() and dist.is_initialized()
        if not using_dist or not self.training \
                or dist.get_world_size() == 1:
            return super().forward(input)

        # Same dimensionality check the fallback path performs.
        self._check_input_dim(input)
        assert input.shape[0] > 0, 'SyncBN does not support empty inputs'

        world_size = dist.get_world_size()
        C = input.shape[1]
        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        # Send both statistics through a single collective call.
        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / world_size)

        mean, meansqr = torch.split(vec, C)
        var = meansqr - mean * mean  # Var[x] = E[x^2] - E[x]^2
        self.running_mean += self.momentum * (
            mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)

        # Fold normalization and the affine transform into one scale/bias.
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return input * scale + bias