resnet.py 15.9 KB
Newer Older
1
2
import logging

Kai Chen's avatar
Kai Chen committed
3
4
import torch.nn as nn
import torch.utils.checkpoint as cp
5
from torch.nn.modules.batchnorm import _BatchNorm
Kai Chen's avatar
Kai Chen committed
6
7

from mmcv.cnn import constant_init, kaiming_init
Kai Chen's avatar
Kai Chen committed
8
from mmcv.runner import load_checkpoint
Kai Chen's avatar
Kai Chen committed
9

10
from mmdet.ops import DeformConv, ModulatedDeformConv, ContextBlock
Kai Chen's avatar
Kai Chen committed
11
from ..registry import BACKBONES
12
from ..utils import build_conv_layer, build_norm_layer
Kai Chen's avatar
Kai Chen committed
13
14
15
16
17
18
19
20
21
22
23


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convs, used by ResNet-18/34.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels of both convs. ``expansion``
            is 1, so this is also the width of the block's output.
        stride (int): Stride of the first 3x3 conv.
        dilation (int): Dilation (and padding) of the first 3x3 conv.
        downsample (nn.Module, optional): Applied to the input to build the
            shortcut when it is not None.
        style (str): Accepted for signature parity with ``Bottleneck``;
            not used by this block.
        with_cp (bool): Run the residual branch under gradient checkpointing
            (only when the input requires grad), trading compute for memory.
        conv_cfg (dict, optional): Config forwarded to ``build_conv_layer``.
        norm_cfg (dict): Config forwarded to ``build_norm_layer``.
        dcn (dict, optional): Not supported by this block; must be None.
        gcb (dict, optional): Not supported by this block; must be None.
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 gcb=None):
        super(BasicBlock, self).__init__()
        assert dcn is None, "Not implemented yet."
        assert gcb is None, "Not implemented yet."

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        # Norm layers are registered under the names chosen by
        # build_norm_layer so checkpoint keys match; see the properties below.
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        # Same checkpointing scheme as Bottleneck: the final ReLU stays
        # outside the checkpointed region.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck block for ResNet (1x1 -> 3x3 -> 1x1 convs).

    If style is "pytorch", the stride-two layer is the 3x3 conv layer,
    if it is "caffe", the stride-two layer is the first 1x1 conv layer.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Width of the inner convs; the block outputs
            ``planes * expansion`` channels.
        stride (int): Stride applied by conv1 ("caffe") or conv2 ("pytorch").
        dilation (int): Dilation (and padding) of the 3x3 conv2.
        downsample (nn.Module, optional): Applied to the input to build the
            shortcut when it is not None.
        style (str): "pytorch" or "caffe".
        with_cp (bool): Run the residual branch under gradient checkpointing
            (only when the input requires grad).
        conv_cfg (dict, optional): Config forwarded to ``build_conv_layer``.
        norm_cfg (dict): Config forwarded to ``build_norm_layer``.
        dcn (dict, optional): Deformable conv2 options. Keys read here:
            ``fallback_on_stride``, ``modulated``, ``deformable_groups``.
        gcb (dict, optional): Keyword arguments for a ``ContextBlock``
            applied after norm3.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 gcb=None):
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        assert gcb is None or isinstance(gcb, dict)
        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
        self.gcb = gcb
        self.with_gcb = gcb is not None
        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)

        fallback_on_stride = False
        self.with_modulated_dcn = False
        self.deformable_groups = 1
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
            self.deformable_groups = dcn.get('deformable_groups', 1)
        # conv2 is deformable only when DCN is requested AND not disabled via
        # ``fallback_on_stride``; forward() must dispatch on this, not on
        # ``with_dcn``, because the fallback path builds no conv2_offset.
        self.conv2_is_dcn = self.with_dcn and not fallback_on_stride

        if not self.conv2_is_dcn:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            # Per deformable group, a 3x3 kernel needs 2 * 3 * 3 = 18 offset
            # channels; the modulated variant adds 3 * 3 = 9 mask channels.
            if not self.with_modulated_dcn:
                conv_op = DeformConv
                offset_channels = 18
            else:
                conv_op = ModulatedDeformConv
                offset_channels = 27
            self.conv2_offset = nn.Conv2d(
                planes,
                self.deformable_groups * offset_channels,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation)
            self.conv2 = conv_op(
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                deformable_groups=self.deformable_groups,
                bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        if self.with_gcb:
            gcb_inplanes = planes * self.expansion
            self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if not self.conv2_is_dcn:
                out = self.conv2(out)
            elif self.with_modulated_dcn:
                offset_mask = self.conv2_offset(out)
                # conv2_offset emits 27 * deformable_groups channels:
                # 18 * groups offsets followed by 9 * groups mask logits.
                num_offset = 18 * self.deformable_groups
                offset = offset_mask[:, :num_offset, :, :]
                mask = offset_mask[:, num_offset:, :, :].sigmoid()
                out = self.conv2(out, offset, mask)
            else:
                offset = self.conv2_offset(out)
                out = self.conv2(out, offset)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_gcb:
                out = self.context_block(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   conv_cfg=None,
                   norm_cfg=dict(type='BN'),
                   dcn=None,
                   gcb=None):
    """Build one ResNet stage as an ``nn.Sequential`` of residual blocks.

    Only the first block receives ``stride`` and, when the stride or channel
    count changes, a 1x1 conv + norm shortcut (``downsample``). The remaining
    ``blocks - 1`` blocks use stride 1 and take
    ``planes * block.expansion`` input channels.

    Args:
        block (type): Block class, e.g. ``BasicBlock`` or ``Bottleneck``.
        inplanes (int): Input channels of the stage.
        planes (int): Inner width of each block.
        blocks (int): Number of blocks. Values below 1 still produce the
            single first block.
        stride, dilation, style, with_cp, conv_cfg, norm_cfg, dcn, gcb:
            Forwarded to every block (``stride`` to the first one only).

    Returns:
        nn.Sequential: The stacked blocks.
    """
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, planes * block.expansion)[1],
        )

    layers = [
        block(
            inplanes,
            planes,
            stride,
            dilation,
            downsample,
            style=style,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            dcn=dcn,
            gcb=gcb)
    ]
    inplanes = planes * block.expansion
    for _ in range(1, blocks):
        layers.append(
            block(
                inplanes,
                planes,
                1,
                dilation,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                gcb=gcb))

    return nn.Sequential(*layers)


@BACKBONES.register_module
class ResNet(nn.Module):
    """ResNet backbone.

    The stem is a 7x7 stride-2 conv, a norm layer, ReLU and a 3x3 stride-2
    max-pool, followed by ``num_stages`` residual stages
    (``layer1`` ... ``layerN``).

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4. Must be in [1, 4].
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters; 0 freezes only the stem.
        conv_cfg (dict, optional): Config used to build conv layers.
        norm_cfg (dict): Dictionary to construct and configure norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Only affects BatchNorm and
            its variants.
        dcn (dict, optional): Deformable conv config passed to the blocks of
            the stages enabled in ``stage_with_dcn``.
        stage_with_dcn (Sequence[bool]): Per-stage switch for ``dcn``.
        gcb (dict, optional): Context block config passed to the blocks of
            the stages enabled in ``stage_with_gcb``.
        stage_with_gcb (Sequence[bool]): Per-stage switch for ``gcb``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): Whether to use zero init for the last norm
            layer in resblocks to let them behave as identity.
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 gcb=None,
                 stage_with_gcb=(False, False, False, False),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.gcb = gcb
        self.stage_with_gcb = stage_with_gcb
        if gcb is not None:
            assert len(stage_with_gcb) == num_stages
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            # Per-stage configs get their own names so the constructor
            # arguments ``dcn``/``gcb`` are not shadowed inside the loop.
            stage_dcn = self.dcn if self.stage_with_dcn[i] else None
            stage_gcb = self.gcb if self.stage_with_gcb[i] else None
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=stage_dcn,
                gcb=stage_gcb)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count of the last stage's output.
        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self):
        """Build the stem: 7x7/2 conv, norm, ReLU and 3x3/2 max-pool."""
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            3,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Put the stem (if frozen_stages >= 0) and layer1..layerN
        (N = frozen_stages) in eval mode and stop their gradients."""
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load a checkpoint from ``pretrained`` or initialize from scratch.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            # Zero offsets make deformable convs start out as regular convs.
            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Return a tuple of the stage outputs selected by ``out_indices``."""
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode, keeping frozen stages (and, with
        ``norm_eval``, BatchNorm layers) in eval mode.

        Returns:
            ResNet: ``self``, matching ``nn.Module.train`` so that
            ``model.train()`` / ``model.eval()`` can be chained.
        """
        super(ResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self