import logging

import torch.nn as nn
import torch.utils.checkpoint as cp

from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint

from mmdet.ops import DeformConv, ModulatedDeformConv
from ..registry import BACKBONES
from ..utils import build_norm_layer


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Build a bias-free 3x3 convolution whose padding equals its dilation.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Stride of the convolution.
        dilation (int): Dilation of the kernel; also used as the padding.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by a norm.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels produced by both convolutions.
        stride (int): Stride of the first convolution.
        dilation (int): Dilation of the first convolution.
        downsample (nn.Module | None): Module applied to the input to build
            the shortcut branch; the raw input is used when it is None.
        style (str): Accepted for signature parity with ``Bottleneck``; it is
            not used by this block.
        with_cp (bool): Must be False; checkpointing is not supported here.
        normalize (dict): Config passed to ``build_norm_layer``.
        dcn (dict | None): Must be None; accepted only so that
            ``make_res_layer`` can pass the same keyword arguments to every
            block type.
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN'),
                 dcn=None):
        super(BasicBlock, self).__init__()
        # Reject unsupported options before any layer is built.
        assert not with_cp
        assert dcn is None, 'DCN is not implemented for BasicBlock'

        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)

        # Norm layers are registered under the names returned by
        # build_norm_layer and reached through the norm1/norm2 properties.
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = conv3x3(planes, planes)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    @property
    def norm1(self):
        """Norm layer that follows ``conv1``."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Norm layer that follows ``conv2``."""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Apply conv-norm-relu-conv-norm, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.norm2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 conv, 3x3 conv, 1x1 conv.

    The 3x3 convolution can be replaced by a (modulated) deformable
    convolution through the ``dcn`` config.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN'),
                 dcn=None):
        """Bottleneck block for ResNet.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.

        Args:
            inplanes (int): Number of input channels.
            planes (int): Channels of the two inner convolutions; the block
                outputs ``planes * expansion`` channels.
            stride (int): Stride of the block (placed according to ``style``).
            dilation (int): Dilation (and padding) of the 3x3 convolution.
            downsample (nn.Module | None): Module applied to the input to
                build the shortcut; the raw input is used when it is None.
            style (str): ``'pytorch'`` or ``'caffe'``.
            with_cp (bool): Run the block body under
                ``torch.utils.checkpoint`` when the input requires grad.
            normalize (dict): Config passed to ``build_norm_layer``.
            dcn (dict | None): Deformable conv config. Keys read here:
                ``fallback_on_stride`` (use a regular conv for conv2),
                ``modulated`` (use ``ModulatedDeformConv``) and
                ``deformable_groups``.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        self.inplanes = inplanes
        self.planes = planes
        self.normalize = normalize
        self.dcn = dcn
        self.with_dcn = dcn is not None
        if style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            normalize, planes * self.expansion, postfix=3)

        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)

        fallback_on_stride = False
        deformable_groups = 1
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            deformable_groups = dcn.get('deformable_groups', 1)
            self.with_modulated_dcn = dcn.get('modulated', False)
        # Kept on the instance so that forward() picks the same conv2 path
        # that is constructed below.
        self.fallback_on_stride = fallback_on_stride
        self.deformable_groups = deformable_groups

        if not self.with_dcn or fallback_on_stride:
            self.conv2 = nn.Conv2d(
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            if not self.with_modulated_dcn:
                conv_op = DeformConv
                # 18 offset channels per deformable group.
                offset_channels = 18
            else:
                conv_op = ModulatedDeformConv
                # 18 offset channels + 9 mask channels per deformable group.
                offset_channels = 27
            self.conv2_offset = nn.Conv2d(
                planes,
                deformable_groups * offset_channels,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation)
            self.conv2 = conv_op(
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                deformable_groups=deformable_groups,
                bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    @property
    def norm1(self):
        """Norm layer that follows ``conv1``."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Norm layer that follows ``conv2``."""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """Norm layer that follows ``conv3``."""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Run the block; the final ReLU is applied outside the checkpoint."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if not self.with_dcn or self.fallback_on_stride:
                # conv2 is a regular convolution: no offset branch exists.
                out = self.conv2(out)
            elif self.with_modulated_dcn:
                offset_mask = self.conv2_offset(out)
                # conv2_offset emits 27 channels per deformable group:
                # the first 18 * groups are offsets, the rest are masks.
                num_offset = 18 * self.deformable_groups
                offset = offset_mask[:, :num_offset, :, :]
                mask = offset_mask[:, num_offset:, :, :].sigmoid()
                out = self.conv2(out, offset, mask)
            else:
                offset = self.conv2_offset(out)
                out = self.conv2(out, offset)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   normalize=dict(type='BN'),
                   dcn=None):
    """Stack ``blocks`` residual blocks into one ``nn.Sequential`` stage.

    Only the first block receives ``stride`` and the shortcut ``downsample``
    module; the remaining blocks use stride 1 and take
    ``planes * block.expansion`` input channels.

    Args:
        block (type): Block class; must expose an ``expansion`` attribute.
        inplanes (int): Input channels of the first block.
        planes (int): Base channel count passed to every block.
        blocks (int): Number of blocks in the stage.
        stride (int): Stride of the first block.
        dilation (int): Dilation passed to every block.
        style (str): Forwarded to every block.
        with_cp (bool): Forwarded to every block.
        normalize (dict): Norm config, forwarded to every block and used for
            the shortcut norm layer.
        dcn (dict | None): Forwarded to every block.

    Returns:
        nn.Sequential: The assembled stage.
    """
    out_planes = planes * block.expansion

    # A projection shortcut is needed whenever the first block changes the
    # spatial size or the channel count.
    downsample = None
    if stride != 1 or inplanes != out_planes:
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                out_planes,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(normalize, out_planes)[1],
        )

    layers = [
        block(
            inplanes,
            planes,
            stride,
            dilation,
            downsample,
            style=style,
            with_cp=with_cp,
            normalize=normalize,
            dcn=dcn)
    ]
    inplanes = out_planes
    for _ in range(1, blocks):
        layers.append(
            block(
                inplanes,
                planes,
                1,
                dilation,
                style=style,
                with_cp=with_cp,
                normalize=normalize,
                dcn=dcn))

    return nn.Sequential(*layers)


@BACKBONES.register_module
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        normalize (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        dcn (dict | None): Deformable conv config handed to the blocks of the
            stages selected by ``stage_with_dcn``.
        stage_with_dcn (Sequence[bool]): Per-stage flag telling whether
            ``dcn`` is passed to that stage.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
    """

    # depth -> (block class, number of blocks in each of the four stages)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 normalize=dict(type='BN', frozen=False),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == len(
            stage_with_dcn) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.normalize = normalize
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        # Stages are registered as layer1..layerN; only their names are kept
        # in res_layers and the modules are looked up by name in forward().
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            # Separate local name so the ``dcn`` argument is not shadowed.
            stage_dcn = self.dcn if self.stage_with_dcn[i] else None
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                normalize=normalize,
                dcn=stage_dcn)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count produced by the last constructed stage.
        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        """Norm layer of the stem, registered under ``norm1_name``."""
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self):
        """Create the stem: 7x7 stride-2 conv, norm, ReLU and 3x3 max-pool."""
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1_name, norm1 = build_norm_layer(
            self.normalize, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Disable gradients for the stem and the first ``frozen_stages``
        stages; does nothing when ``frozen_stages`` is negative."""
        if self.frozen_stages >= 0:
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load a checkpoint or initialize the weights from scratch.

        Args:
            pretrained (str | None): Checkpoint path handed to
                ``load_checkpoint``; when None the layers are initialized
                in place.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    constant_init(m, 1)

            # Runs after the generic conv init so that offset branches end
            # up with constant-zero init.
            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            # Zero the last norm layer of every residual block.
            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Return the outputs of the stages listed in ``out_indices``.

        A single tensor is returned when exactly one stage is selected,
        otherwise a tuple of tensors.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode; with ``norm_eval`` set, BatchNorm2d layers
        are put back into eval mode while training."""
        super(ResNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()