import logging

import torch.nn as nn
import torch.utils.checkpoint as cp
from torch.nn.modules.batchnorm import _BatchNorm

from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint

from mmdet.ops import DeformConv, ModulatedDeformConv
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None):
        super(BasicBlock, self).__init__()
        assert dcn is None, "Not implemented yet."

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg,
            planes,
            planes,
            3,
            padding=1,
            bias=False)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
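        # gradient checkpointing (with_cp) is not supported in BasicBlock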
        assert not with_cp

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.norm2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None):
        """Bottleneck block for ResNet.
        If style is "pytorch", the stride-two layer is the 3x3 conv layer;
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
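            # A 3x3 deformable conv needs 2 * 3 * 3 = 18 offset channels (x, y
            # offsets per kernel position); the modulated variant adds a 3 * 3
            # mask, i.e. 27 channels per deformable group.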
            deformable_groups = dcn.get('deformable_groups', 1)
            if not self.with_modulated_dcn:
                conv_op = DeformConv
                offset_channels = 18
            else:
                conv_op = ModulatedDeformConv
                offset_channels = 27
            self.conv2_offset = nn.Conv2d(
                planes,
                deformable_groups * offset_channels,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation)
            self.conv2 = conv_op(
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                deformable_groups=deformable_groups,
                bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        return getattr(self, self.norm3_name)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

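            # conv2 is a plain conv, a DeformConv (extra offset input), or a
            # ModulatedDeformConv (offset plus a sigmoid-gated mask). The
            # hard-coded 18/9 channel split below assumes
            # deformable_groups == 1.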
            if not self.with_dcn:
                out = self.conv2(out)
            elif self.with_modulated_dcn:
                offset_mask = self.conv2_offset(out)
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, -9:, :, :].sigmoid()
                out = self.conv2(out, offset, mask)
            else:
                offset = self.conv2_offset(out)
                out = self.conv2(out, offset)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

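        # With checkpointing enabled, activations inside the block are
        # recomputed during the backward pass to save memory.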
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   conv_cfg=None,
                   norm_cfg=dict(type='BN'),
                   dcn=None):
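    # When the first block changes resolution or channel count, project the
    # shortcut with a 1x1 conv + norm so it matches the residual branch.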
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, planes * block.expansion)[1],
        )

    layers = []
    layers.append(
        block(
            inplanes,
            planes,
            stride,
            dilation,
            downsample,
            style=style,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            dcn=dcn))
    inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(
            block(
                inplanes,
                planes,
                1,
                dilation,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn))

    return nn.Sequential(*layers)


@BACKBONES.register_module
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer; otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        norm_cfg (dict): Dictionary used to construct and configure the norm
            layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): Whether to zero-init the last norm layer in
            resblocks so that they initially behave as identity mappings.
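
    Example (a minimal shape-check sketch; the import path below assumes the
    standard mmdet package layout):
        >>> import torch
        >>> from mmdet.models import ResNet
        >>> self = ResNet(depth=18)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 8, 8)
        (1, 128, 4, 4)
        (1, 256, 2, 2)
        (1, 512, 1, 1)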
    """

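    # depth -> (block class, number of blocks in each of the four stages)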
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

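        # number of output channels of the last stage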
        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self):
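        # Standard ResNet stem: a 7x7 stride-2 conv followed by a 3x3 stride-2
        # max pool, giving an overall stride of 4 before the first stage.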
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            3,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
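        # Put the stem and the first `frozen_stages` residual stages into eval
        # mode and stop their gradients.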
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

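            # Zero-init the last norm layer of each residual branch so that
            # every block initially acts as an identity mapping.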
            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        super(ResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() only affects BatchNorm and its variants
                if isinstance(m, _BatchNorm):
                    m.eval()