resnet.py 15.7 KB
Newer Older
1
import logging
2
import pickle
3

4
import torch
Kai Chen's avatar
Kai Chen committed
5
6
import torch.nn as nn
import torch.utils.checkpoint as cp
Kai Chen's avatar
Kai Chen committed
7
8

from mmcv.cnn import constant_init, kaiming_init
Kai Chen's avatar
Kai Chen committed
9
from mmcv.runner import load_checkpoint
10
from ..utils import build_norm_layer
Kai Chen's avatar
Kai Chen committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    Padding is set equal to ``dilation``, which keeps the spatial size
    unchanged when ``stride`` is 1.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convs (used for ResNet-18/34).

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels of both 3x3 convs.
        stride (int): Stride of the first 3x3 conv.
        dilation (int): Dilation (and padding) of the first 3x3 conv only.
        downsample (nn.Module, optional): Applied to the block input to form
            the shortcut when the shape changes; ``None`` keeps the identity.
        style (str): Accepted for signature parity with ``Bottleneck``; this
            block does not use it.
        with_cp (bool): Must be False; checkpointing is not implemented for
            this block.
        normalize (dict, optional): Norm config forwarded to
            ``build_norm_layer``. Defaults to ``dict(type='BN')``.
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=None):
        super(BasicBlock, self).__init__()
        if normalize is None:
            # BN default matches Bottleneck, make_res_layer and ResNet. The
            # previous bare dict(type='GN') default lacked the 'num_groups'
            # key that ResNet requires for GN, and was a shared mutable dict.
            normalize = dict(type='BN')
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)

        norm_layers = []
        norm_layers.append(build_norm_layer(normalize, planes))
        norm_layers.append(build_norm_layer(normalize, planes))
        # Register norms under type-dependent names so state-dict keys
        # reflect the norm type ('gn*' for GroupNorm, 'bn*' otherwise).
        self.norm_names = (['gn1', 'gn2'] if normalize['type'] == 'GN'
                           else ['bn1', 'bn2'])
        for name, layer in zip(self.norm_names, norm_layers):
            self.add_module(name, layer)

        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        assert not with_cp

    def forward(self, x):
        """conv-norm-relu, conv-norm, add shortcut, relu."""
        residual = x

        out = self.conv1(x)
        out = getattr(self, self.norm_names[0])(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = getattr(self, self.norm_names[1])(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual block of a 1x1 -> 3x3 -> 1x1 conv stack (ResNet-50/101/152).

    The last 1x1 conv widens the output to ``planes * expansion`` channels.
    ``style`` chooses which conv carries ``stride``: ``'pytorch'`` puts it on
    the 3x3 conv, ``'caffe'`` on the first 1x1 conv. When ``with_cp`` is set
    and the input requires grad, the residual computation runs under
    ``torch.utils.checkpoint``.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        # (conv1 stride, conv2 stride) for the selected style.
        conv1_stride, conv2_stride = ((1, stride) if style == 'pytorch'
                                      else (stride, 1))

        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=conv2_stride,
            padding=dilation, dilation=dilation, bias=False)

        out_channels = planes * self.expansion
        # Build the three norm layers first, then register them under
        # type-dependent names ('gn*' for GroupNorm, 'bn*' otherwise).
        built_norms = [build_norm_layer(normalize, num_features)
                       for num_features in (planes, planes, out_channels)]
        tag = 'gn' if normalize['type'] == 'GN' else 'bn'
        self.norm_names = ['{}{}'.format(tag, idx) for idx in (1, 2, 3)]
        for norm_name, norm_layer in zip(self.norm_names, built_norms):
            self.add_module(norm_name, norm_layer)

        self.conv3 = nn.Conv2d(
            planes, out_channels, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp
        self.normalize = normalize

    def forward(self, x):
        """Run the three conv-norm stages, add the shortcut, then ReLU."""

        def _inner_forward(inp):
            out = inp
            stages = zip((self.conv1, self.conv2, self.conv3), self.norm_names)
            for step, (conv, norm_name) in enumerate(stages):
                out = getattr(self, norm_name)(conv(out))
                # The third stage's ReLU is applied after the residual add.
                if step < 2:
                    out = self.relu(out)

            shortcut = inp if self.downsample is None else self.downsample(inp)
            out += shortcut
            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)
        return self.relu(out)


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   normalize=dict(type='BN')):
    """Stack ``blocks`` residual blocks into one ResNet stage.

    Only the first block receives ``stride`` and, when the stride or the
    channel count changes, a 1x1 conv + norm projection for its shortcut.
    Every following block uses stride 1 and an identity shortcut.

    Returns:
        nn.Sequential: the blocks of the stage, in order.
    """
    out_planes = planes * block.expansion

    downsample = None
    if stride != 1 or inplanes != out_planes:
        projection = nn.Conv2d(
            inplanes, out_planes, kernel_size=1, stride=stride, bias=False)
        downsample = nn.Sequential(
            projection, build_norm_layer(normalize, out_planes))

    shared_kwargs = dict(style=style, with_cp=with_cp, normalize=normalize)
    stage = [
        block(inplanes, planes, stride, dilation, downsample, **shared_kwargs)
    ]
    stage.extend(
        block(out_planes, planes, 1, dilation, **shared_kwargs)
        for _ in range(blocks - 1))
    return nn.Sequential(*stage)


Kai Chen's avatar
Kai Chen committed
201
202
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        normalize (dict): Normalization config; ``type`` must be ``'BN'`` or
            ``'GN'``. A ``'GN'`` config must contain ``num_groups``. A ``'BN'``
            config must contain exactly ``type`` plus these keys:

            - frozen_stages (int): Stages to be frozen (all param fixed).
              -1 means not freezing any parameters.
            - bn_eval (bool): Whether to set BN layers to eval mode, namely,
              freeze running stats (mean and var).
            - bn_frozen (bool): Whether to freeze weight and bias of BN
              layers. Only takes effect when ``bn_eval`` is True.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Only supported by
            ``Bottleneck`` (``BasicBlock`` asserts it is False).
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 normalize=dict(
                     type='BN',
                     frozen_stages=-1,
                     bn_eval=True,
                     bn_frozen=False),
                 with_cp=False):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        assert 1 <= num_stages <= 4
        block, stage_blocks = self.arch_settings[depth]
        # Only the first `num_stages` stages are built.
        stage_blocks = stage_blocks[:num_stages]
        assert len(strides) == len(dilations) == num_stages
        assert max(out_indices) < num_stages

        assert isinstance(normalize, dict) and 'type' in normalize
        assert normalize['type'] in ['BN', 'GN']
        if normalize['type'] == 'GN':
            assert 'num_groups' in normalize
        else:
            # BN configs carry the freezing policy applied in train().
            assert (set(['type', 'frozen_stages', 'bn_eval', 'bn_frozen'])
                    == set(normalize))

        self.out_indices = out_indices
        self.style = style
        self.with_cp = with_cp
        # The freezing attributes exist only for BN; train() checks the
        # norm type before reading them.
        if normalize['type'] == 'BN':
            self.frozen_stages = normalize['frozen_stages']
            self.bn_eval = normalize['bn_eval']
            self.bn_frozen = normalize['bn_frozen']
        self.normalize = normalize

        # Stem: 7x7 stride-2 conv -> norm -> ReLU -> 3x3 stride-2 max-pool.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        stem_norm = build_norm_layer(normalize, 64)
        self.stem_norm_name = 'gn1' if normalize['type'] == 'GN' else 'bn1'
        self.add_module(self.stem_norm_name, stem_norm)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages, registered as layer1..layerN; stage i uses
        # 64 * 2**i base planes.
        self.res_layers = []
        for i, num_blocks in enumerate(stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = 64 * 2**i
            res_layer = make_res_layer(
                block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                normalize=normalize)
            self.inplanes = planes * block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        # Channel count of the last built stage's output.
        self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)

    def init_weights(self, pretrained=None):
        """Initialize weights from a checkpoint or from scratch.

        Args:
            pretrained (str | None): If a str, it is forwarded to mmcv's
                ``load_checkpoint`` with ``strict=False``. If None, convs get
                ``kaiming_init``, BatchNorm layers get ``constant_init(m, 1)``
                and the last norm of every residual block gets
                ``constant_init(norm, 0)``.

        Raises:
            TypeError: if ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)

            # zero init for last norm layer https://arxiv.org/abs/1706.02677
            for m in self.modules():
                if isinstance(m, (Bottleneck, BasicBlock)):
                    last_norm = getattr(m, m.norm_names[-1])
                    constant_init(last_norm, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Return the outputs of the stages listed in ``out_indices``.

        A single tensor is returned when exactly one stage is selected,
        otherwise a tuple ordered by stage index.
        """
        x = self.conv1(x)
        x = getattr(self, self.stem_norm_name)(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        """Set train/eval mode, then re-apply the BN freezing policy.

        Returns:
            ResNet: ``self``, as ``nn.Module.train`` does.
        """
        super(ResNet, self).train(mode)
        if self.normalize['type'] != 'BN':
            return self

        if self.bn_eval:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    # Keep running statistics fixed even in training mode.
                    m.eval()
                    if self.bn_frozen:
                        for params in m.parameters():
                            params.requires_grad = False

        if mode and self.frozen_stages >= 0:
            # Freeze the stem (conv1 + its norm) and stages 1..frozen_stages.
            stem_norm = getattr(self, self.stem_norm_name)
            for param in self.conv1.parameters():
                param.requires_grad = False
            stem_norm.eval()
            for param in stem_norm.parameters():
                param.requires_grad = False
            for i in range(1, self.frozen_stages + 1):
                mod = getattr(self, 'layer{}'.format(i))
                mod.eval()
                for param in mod.parameters():
                    param.requires_grad = False
        return self


class ResNetClassifier(ResNet):
    """ResNet backbone followed by global average pooling and a linear head.

    Takes the same arguments as ``ResNet`` plus ``num_classes``, the output
    size of the final ``fc`` layer. ``init_weights()`` is called by the
    constructor.
    """

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 normalize=dict(
                     type='BN',
                     frozen_stages=-1,
                     bn_eval=True,
                     bn_frozen=False),
                 with_cp=False,
                 num_classes=1000):
        super(ResNetClassifier, self).__init__(depth,
                                               num_stages=num_stages,
                                               strides=strides,
                                               dilations=dilations,
                                               out_indices=out_indices,
                                               style=style,
                                               normalize=normalize,
                                               with_cp=with_cp)
        # Keep only the stages that ResNet.__init__ actually built, so
        # load_caffe2_weight never maps missing layers.
        self.stage_blocks = self.arch_settings[depth][1][:num_stages]
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # feat_dim is the channel count of the last built stage. This is
        # correct for BasicBlock depths (18 *and* 34) and for num_stages < 4;
        # the former `1 if depth == 18 else 4` gave ResNet-34 a 2048-input fc.
        self.fc = nn.Linear(self.feat_dim, num_classes)

        self.init_weights()

    # TODO can be removed after tested
    def load_caffe2_weight(self, cf_path):
        """Load weights from a pickled Caffe2-style blob file.

        Builds a mapping from state-dict keys to blob names such as
        ``res2_0_branch2a_w``, copies each mapped blob into the state dict,
        then calls ``load_state_dict``.

        The per-block mapping covers ``conv1``..``conv3`` and a downsample
        branch in every stage's first block, i.e. the ``Bottleneck`` layout.
        For ``BasicBlock`` depths (18/34) the key assertion below fails.

        Args:
            cf_path (str): Path to the pickle file. Its top level is either
                the blob dict itself or a dict holding it under ``'blobs'``.

        Raises:
            AssertionError: if a mapped key is missing on either side.
        """
        norm = 'gn' if self.normalize['type'] == 'GN' else 'bn'
        mapping = {}

        for layer, blocks_in_layer in enumerate(self.stage_blocks, 1):
            for blk in range(blocks_in_layer):
                # Caffe2 stage numbering starts at res2 for our layer1.
                cf_prefix = 'res%d_%d_' % (layer + 1, blk)
                py_prefix = 'layer%d.%d.' % (layer, blk)

                # conv branch
                for i, a in zip([1, 2, 3], ['a', 'b', 'c']):
                    cf_full = cf_prefix + 'branch2%s_' % a
                    mapping[py_prefix + 'conv%d.weight' % i] = cf_full + 'w'
                    mapping[py_prefix + norm + '%d.weight' % i] \
                        = cf_full + norm + '_s'
                    mapping[py_prefix + norm + '%d.bias' % i] \
                        = cf_full + norm + '_b'

            # downsample branch
            cf_full = 'res%d_0_branch1_' % (layer + 1)
            py_full = 'layer%d.0.downsample.' % layer
            mapping[py_full + '0.weight'] = cf_full + 'w'
            mapping[py_full + '1.weight'] = cf_full + norm + '_s'
            mapping[py_full + '1.bias'] = cf_full + norm + '_b'

        # stem layers and last fc layer
        if self.normalize['type'] == 'GN':
            mapping['conv1.weight'] = 'conv1_w'
            mapping['gn1.weight'] = 'conv1_gn_s'
            mapping['gn1.bias'] = 'conv1_gn_b'
            mapping['fc.weight'] = 'pred_w'
            mapping['fc.bias'] = 'pred_b'
        else:
            mapping['conv1.weight'] = 'conv1_w'
            mapping['bn1.weight'] = 'res_conv1_bn_s'
            mapping['bn1.bias'] = 'res_conv1_bn_b'
            mapping['fc.weight'] = 'fc1000_w'
            mapping['fc.bias'] = 'fc1000_b'

        # load state dict
        py_state = self.state_dict()
        with open(cf_path, 'rb') as f:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load weight files from trusted sources.
            cf_state = pickle.load(f, encoding='latin1')
            if 'blobs' in cf_state:
                cf_state = cf_state['blobs']
            for py_k, cf_k in mapping.items():
                print('Loading {} to {}'.format(cf_k, py_k))
                assert py_k in py_state and cf_k in cf_state
                py_state[py_k] = torch.Tensor(cf_state[cf_k])
        self.load_state_dict(py_state)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)``."""
        x = self.conv1(x)
        x = getattr(self, self.stem_norm_name)(x)
        x = self.relu(x)
        x = self.maxpool(x)
        for layer_name in self.res_layers:
            x = getattr(self, layer_name)(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)