import warnings

import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init

from .conv_ws import ConvWS2d
from .norm import build_norm_layer

conv_cfg = {
    'Conv': nn.Conv2d,
    'ConvWS': ConvWS2d,
    # TODO: octave conv
}


def build_conv_layer(cfg, *args, **kwargs):
    """ Build convolution layer

    Args:
        cfg (None or dict): cfg should contain:
            type (str): identify conv layer type.
            layer args: args needed to instantiate a conv layer.

    Returns:
        layer (nn.Module): created conv layer
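
    Example:
        >>> # a minimal usage sketch: cfg=None falls back to a plain nn.Conv2d
        >>> conv = build_conv_layer(None, 3, 16, 3, padding=1)
        >>> conv_ws = build_conv_layer(dict(type='ConvWS'), 3, 16, 3)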
    """
    if cfg is None:
        cfg_ = dict(type='Conv')
    else:
        assert isinstance(cfg, dict) and 'type' in cfg
        cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in conv_cfg:
        raise KeyError('Unrecognized conv type {}'.format(layer_type))
    conv_layer = conv_cfg[layer_type]

    layer = conv_layer(*args, **kwargs, **cfg_)

    return layer


class ConvModule(nn.Module):
    """A conv block that contains conv/norm/activation layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by
            norm_cfg: bias is set to True if norm_cfg is None, otherwise
            False.
        conv_cfg (dict): Config dict for convolution layer.
        norm_cfg (dict): Config dict for normalization layer.
        activation (str or None): Activation type, "relu" by default
            (currently the only supported type).
        inplace (bool): Whether to use inplace mode for activation.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_cfg=None,
                 norm_cfg=None,
                 activation='relu',
                 inplace=True,
                 order=('conv', 'norm', 'act')):
        super(ConvModule, self).__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.activation = activation
        self.inplace = inplace
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == {'conv', 'norm', 'act'}

        self.with_norm = norm_cfg is not None
        self.with_activation = activation is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == 'auto':
            bias = False if self.with_norm else True
        self.with_bias = bias

        if self.with_norm and self.with_bias:
            warnings.warn('ConvModule has norm and bias at the same time')

        # build convolution layer
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = self.conv.padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        # build normalization layer
        if self.with_norm:
            # if norm comes after conv, it normalizes the conv output;
            # otherwise (norm before conv) it normalizes the conv input
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
            self.add_module(self.norm_name, norm)

        # build activation layer
        if self.with_activation:
            # TODO: introduce `act_cfg` and supports more activation layers
            if self.activation not in ['relu']:
                raise ValueError('{} is currently not supported.'.format(
                    self.activation))
            if self.activation == 'relu':
                self.activate = nn.ReLU(inplace=inplace)

        # Use msra init by default
        self.init_weights()

    @property
    def norm(self):
        return getattr(self, self.norm_name)

    def init_weights(self):
        nonlinearity = 'relu' if self.activation is None else self.activation
        kaiming_init(self.conv, nonlinearity=nonlinearity)
        if self.with_norm:
            constant_init(self.norm, 1, bias=0)

    def forward(self, x, activate=True, norm=True):
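        """Apply the layers in the order given by `self.order`. `activate`
        and `norm` can be set to False to skip the activation or norm step
        for a single call."""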
        for layer in self.order:
            if layer == 'conv':
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_activation:
                x = self.activate(x)
        return x