import logging

import torch.nn as nn
import torch.utils.checkpoint as cp
from torch.nn.modules.batchnorm import _BatchNorm

from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint

from mmdet.ops import DeformConv, ModulatedDeformConv
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None):
        super(BasicBlock, self).__init__()
        assert dcn is None, "Not implemented yet."

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        assert not with_cp

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.norm2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None):
        """Bottleneck block for ResNet.
        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
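        # As described in the docstring above: 'pytorch' style puts the
        # stride-two conv at conv2 (3x3), 'caffe' style at conv1 (first 1x1).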
        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
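        # conv2 is a plain 3x3 conv unless DCN is requested and not
        # configured to fall back to a regular conv (fallback_on_stride).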
        fallback_on_stride = False
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            deformable_groups = dcn.get('deformable_groups', 1)
            if not self.with_modulated_dcn:
                conv_op = DeformConv
                offset_channels = 18
            else:
                conv_op = ModulatedDeformConv
                offset_channels = 27
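            # conv2_offset predicts the sampling offsets for the deformable
            # conv (18 = 2*3*3 channels per group); the modulated variant
            # additionally predicts a 3*3-channel mask (27 channels total).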
            self.conv2_offset = nn.Conv2d(
                planes,
                deformable_groups * offset_channels,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation)
            self.conv2 = conv_op(
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                deformable_groups=deformable_groups,
                bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        return getattr(self, self.norm3_name)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if not self.with_dcn:
                out = self.conv2(out)
            elif self.with_modulated_dcn:
                offset_mask = self.conv2_offset(out)
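                # NOTE: the 18-/9-channel split below assumes
                # deformable_groups == 1 (2*3*3 offsets plus 3*3 masks).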
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, -9:, :, :].sigmoid()
                out = self.conv2(out, offset, mask)
            else:
                offset = self.conv2_offset(out)
                out = self.conv2(out, offset)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   conv_cfg=None,
                   norm_cfg=dict(type='BN'),
                   dcn=None):
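    # Build one ResNet stage: an optional projection shortcut (1x1 conv +
    # norm) when the resolution or channel width changes, followed by
    # `blocks` residual blocks.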
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, planes * block.expansion)[1],
        )

    layers = []
    layers.append(
        block(
            inplanes,
            planes,
            stride,
            dilation,
            downsample,
            style=style,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            dcn=dcn))
    inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(
            block(
                inplanes,
                planes,
                1,
                dilation,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn))

    return nn.Sequential(*layers)


@BACKBONES.register_module
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        norm_cfg (dict): Dictionary to construct and configure the norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
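
    Example:
        >>> # Illustrative usage sketch; the import path and the printed
        >>> # shapes assume the default 4-stage configuration.
        >>> from mmdet.models import ResNet
        >>> import torch
        >>> model = ResNet(depth=18)
        >>> _ = model.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> outputs = model(inputs)
        >>> for out in outputs:
        ...     print(tuple(out.shape))
        (1, 64, 8, 8)
        (1, 128, 4, 4)
        (1, 256, 2, 2)
        (1, 512, 1, 1)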
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self):
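        # Standard ResNet stem: a 7x7 stride-2 conv followed by a 3x3
        # stride-2 max pool, reducing the input resolution by 4x.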
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            3,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
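        # Put the stem and the first `frozen_stages` residual stages into
        # eval mode and stop gradients for their parameters.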
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        super(ResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() only has an effect on BatchNorm layers
                if isinstance(m, _BatchNorm):
                    m.eval()