import logging

import torch.nn as nn
import torch.utils.checkpoint as cp

from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from ..utils import build_norm_layer


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``, so at stride 1 the output
    keeps the spatial size of the input.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Convolution stride.
        dilation (int): Convolution dilation, also used as the padding.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels produced by both convolutions.
        stride (int): Stride of the first convolution.
        dilation (int): Dilation (and padding) of the first convolution.
        downsample (nn.Module | None): Module applied to the block input
            before it is added to the output. ``None`` means the input is
            added unchanged.
        style (str): Accepted for signature parity with ``Bottleneck``; this
            block does not read it.
        with_cp (bool): Must be falsy, checkpointing is not implemented for
            this block.
        normalize (dict): Config passed to ``build_norm_layer``.
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)

        # build_norm_layer return: (norm_name, norm_layer)
        # self.normN holds the *name* under which the layer is registered;
        # the layer itself is fetched with getattr(self, self.normN).
        self.norm1, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2, norm2 = build_norm_layer(normalize, planes, postfix=2)
        self.add_module(self.norm1, norm1)
        self.add_module(self.norm2, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        assert not with_cp

    def forward(self, x):
        """Apply conv-norm-relu-conv-norm, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = getattr(self, self.norm1)(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = getattr(self, self.norm2)(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual block of 1x1, 3x3 and 1x1 convolutions.

    The final 1x1 convolution outputs ``planes * expansion`` channels.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels of the first two convolutions.
        stride (int): Stride applied by the block (see ``style``).
        dilation (int): Dilation (and padding) of the 3x3 convolution.
        downsample (nn.Module | None): Module applied to the block input
            before it is added to the output. ``None`` means the input is
            added unchanged.
        style (str): ``'pytorch'`` or ``'caffe'``; selects which convolution
            carries the stride.
        with_cp (bool): Run the block through ``torch.utils.checkpoint``
            when the input requires grad.
        normalize (dict): Config passed to ``build_norm_layer``.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        """Bottleneck block for ResNet.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        self.inplanes = inplanes
        self.planes = planes
        if style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1
        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=dilation,
            dilation=dilation,
            bias=False)

        # build_norm_layer return: (norm_name, norm_layer)
        # self.normN holds the *name* under which the layer is registered;
        # the layer itself is fetched with getattr(self, self.normN).
        self.norm1, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2, norm2 = build_norm_layer(normalize, planes, postfix=2)
        self.norm3, norm3 = build_norm_layer(
            normalize, planes * self.expansion, postfix=3)
        self.add_module(self.norm1, norm1)
        self.add_module(self.norm2, norm2)
        self.add_module(self.norm3, norm3)

        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp
        self.normalize = normalize

    def forward(self, x):
        """Run the three conv/norm stages, add the shortcut, then ReLU."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = getattr(self, self.norm1)(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = getattr(self, self.norm2)(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = getattr(self, self.norm3)(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        # Checkpointing is only used when the input takes part in autograd;
        # otherwise the block is evaluated directly.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   normalize=dict(type='BN')):
    """Stack ``blocks`` residual blocks into one ``nn.Sequential`` stage.

    Only the first block receives ``stride`` and the optional downsample
    shortcut; the remaining blocks use stride 1 and take
    ``planes * block.expansion`` input channels.

    Args:
        block (type): Block class exposing an ``expansion`` attribute.
        inplanes (int): Number of channels entering the stage.
        planes (int): Base channel count passed to every block.
        blocks (int): Number of blocks in the stage.
        stride (int): Stride of the first block.
        dilation (int): Dilation passed to every block.
        style (str): Forwarded to every block.
        with_cp (bool): Forwarded to every block.
        normalize (dict): Norm layer config, forwarded to every block and
            used for the shortcut's norm layer.

    Returns:
        nn.Sequential: The assembled stage.
    """
    downsample = None
    # A projection shortcut is needed whenever the first block changes the
    # spatial size or the channel count.
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(normalize, planes * block.expansion)[1],
        )

    layers = []
    layers.append(
        block(
            inplanes,
            planes,
            stride,
            dilation,
            downsample,
            style=style,
            with_cp=with_cp,
            normalize=normalize))
    inplanes = planes * block.expansion
    for _ in range(1, blocks):
        layers.append(
            block(inplanes, planes, 1, dilation, style=style,
                  with_cp=with_cp, normalize=normalize))

    return nn.Sequential(*layers)


class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        normalize (dict): dictionary to construct norm layer. It must also
            contain the keys ``eval_mode`` (bool) and ``frozen`` (bool).
            ``eval_mode`` is consumed by this class and keeps BatchNorm
            layers in eval mode while training; ``frozen`` is forwarded to
            ``build_norm_layer`` (presumably controlling gradient freezing
            of the norm layers). The dict passed in is not modified.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): Whether ``init_weights`` zero-initializes
            the last norm layer of every residual block.
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 normalize=dict(
                     type='BN',
                     eval_mode=True,
                     frozen=False),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        assert (isinstance(normalize, dict) and 'eval_mode' in normalize
                and 'frozen' in normalize)
        # Pop from a copy: popping from the argument itself would strip
        # 'eval_mode' from the shared default dict (breaking every later
        # default construction) and from the caller's config.
        normalize = normalize.copy()
        self.norm_eval = normalize.pop('eval_mode')
        self.normalize = normalize
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        # Stages are registered as layer1..layerN; only their names are kept
        # here and the modules are looked up with getattr in forward().
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                normalize=self.normalize)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    def _make_stem_layer(self):
        """Build the 7x7 conv / norm / ReLU / max-pool stem.

        The stem conv and norm are frozen when ``frozen_stages >= 0``.
        """
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # self.stem_norm stores the registered name of the norm layer.
        self.stem_norm, stem_norm = build_norm_layer(self.normalize,
                                                     64, postfix=1)
        self.add_module(self.stem_norm, stem_norm)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        if self.frozen_stages >= 0:
            for m in [self.conv1, stem_norm]:
                for param in m.parameters():
                    param.requires_grad = False

    def _freeze_stages(self):
        """Disable gradients for the first ``frozen_stages`` stages."""
        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load a checkpoint or initialize the weights from scratch.

        Args:
            pretrained (str | None): Checkpoint location handed to
                ``load_checkpoint``; ``None`` initializes conv layers with
                ``kaiming_init`` and norm layers with a constant 1.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                # torch.nn has no `BatchNorm` class; BatchNorm2d is the
                # variant this file handles elsewhere (see train()).
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    constant_init(m, 1)

            # zero init for last norm layer https://arxiv.org/abs/1706.02677
            if self.zero_init_residual:
                for m in self.modules():
                    # The blocks store the *name* of each norm layer in
                    # norm1/norm2/norm3, so resolve the last one by name.
                    if isinstance(m, Bottleneck):
                        constant_init(getattr(m, m.norm3), 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(getattr(m, m.norm2), 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stem and all stages, collecting the requested outputs.

        Returns:
            The single stage output when exactly one index is selected by
            ``out_indices``, otherwise a tuple of the selected outputs.
        """
        x = self.conv1(x)
        x = getattr(self, self.stem_norm)(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode, keeping BatchNorm in eval if requested."""
        super(ResNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()