import warnings

import torch.nn as nn
from mmcv.cnn import kaiming_init, constant_init

from .conv_ws import ConvWS2d
from .norm import build_norm_layer

conv_cfg = {
    'Conv': nn.Conv2d,
    'ConvWS': ConvWS2d,
    # TODO: octave conv
}
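# New conv layer types can be registered by adding entries to this dict,
# e.g. conv_cfg['MyConv'] = MyConv2d (MyConv2d is a hypothetical class).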


def build_conv_layer(cfg, *args, **kwargs):
    """ Build convolution layer

    Args:
        cfg (None or dict): cfg should contain:
            type (str): identify conv layer type.
            layer args: args needed to instantiate a conv layer.

    Returns:
        layer (nn.Module): created conv layer
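
    Example:
        >>> # a minimal sketch; channel/kernel values are illustrative
        >>> conv = build_conv_layer(None, 3, 8, 3, padding=1)  # plain Conv2d
        >>> conv_ws = build_conv_layer(dict(type='ConvWS'), 3, 8, 3)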
    """
    if cfg is None:
        cfg_ = dict(type='Conv')
    else:
        assert isinstance(cfg, dict) and 'type' in cfg
        cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in conv_cfg:
        raise KeyError('Unrecognized conv type {}'.format(layer_type))
    else:
        conv_layer = conv_cfg[layer_type]

    layer = conv_layer(*args, **kwargs, **cfg_)

    return layer


class ConvModule(nn.Module):
    """Conv-Norm-Activation block.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
        conv_cfg (dict): Config dict for convolution layer.
        norm_cfg (dict): Config dict for normalization layer.
        activation (str or None): Activation type; 'relu' by default.
        inplace (bool): Whether to use inplace mode for activation.
        activate_last (bool): Whether to apply the activation layer last,
            i.e. use the conv-norm-act order. (Do not use this flag, since
            the behavior and API may change in the future.)
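
    Example:
        >>> # a minimal sketch; sizes and the norm choice are illustrative
        >>> import torch
        >>> m = ConvModule(3, 8, 3, padding=1, norm_cfg=dict(type='BN'))
        >>> out = m(torch.randn(1, 3, 32, 32))  # conv -> BN -> ReLU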
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_cfg=None,
                 norm_cfg=None,
                 activation='relu',
                 inplace=True,
                 activate_last=True):
        super(ConvModule, self).__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.activation = activation
        self.inplace = inplace
        self.activate_last = activate_last

        self.with_norm = norm_cfg is not None
        self.with_activation = activation is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
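        # (the norm subtracts the mean, so a conv bias would be cancelled
        # anyway, and the norm's own shift parameter replaces it)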
        if bias == 'auto':
            bias = False if self.with_norm else True
        self.with_bias = bias

        if self.with_norm and self.with_bias:
            warnings.warn('ConvModule has norm and bias at the same time')

        # build convolution layer
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = self.conv.padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        # build normalization layer
        if self.with_norm:
            norm_channels = out_channels if self.activate_last else in_channels
            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
            self.add_module(self.norm_name, norm)

        # build activation layer
        if self.with_activation:
            if self.activation not in ['relu']:
                raise ValueError('{} is currently not supported.'.format(
                    self.activation))
            if self.activation == 'relu':
                self.activate = nn.ReLU(inplace=inplace)

        # Use msra init by default
        self.init_weights()

    @property
    def norm(self):
        return getattr(self, self.norm_name)

    def init_weights(self):
        nonlinearity = 'relu' if self.activation is None else self.activation
        kaiming_init(self.conv, nonlinearity=nonlinearity)
        if self.with_norm:
            constant_init(self.norm, 1, bias=0)

    def forward(self, x, activate=True, norm=True):
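        """Apply conv/norm/activation in the configured order.

        The ``norm`` and ``activate`` flags allow skipping the norm or
        activation layer for an individual call.
        """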
        if self.activate_last:
            x = self.conv(x)
            if norm and self.with_norm:
                x = self.norm(x)
            if activate and self.with_activation:
                x = self.activate(x)
        else:
            # WARN: this may be removed or modified
            if norm and self.with_norm:
                x = self.norm(x)
            if activate and self.with_activation:
                x = self.activate(x)
            x = self.conv(x)
        return x