import logging

import torch.nn as nn
import torch.utils.checkpoint as cp

from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint

from mmdet.ops import DeformConv, ModulatedDeformConv
from ..registry import BACKBONES
from ..utils import build_norm_layer


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    Padding is set equal to ``dilation``, so with ``stride=1`` the spatial
    size of the input is preserved.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Two-layer residual block (3x3 conv -> 3x3 conv) used by ResNet-18/34.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels of both convs (the block's
            ``expansion`` is 1).
        stride (int): Stride of the first 3x3 conv.
        dilation (int): Dilation (and padding) of the first 3x3 conv.
        downsample (nn.Module, optional): Applied to the block input to form
            the identity branch; None means the input is added unchanged.
        style (str): Accepted for interface parity with ``Bottleneck``;
            not used by this block.
        with_cp (bool): Use gradient checkpointing for the residual branch,
            trading recomputation in backward for lower memory.
        normalize (dict): Norm layer config passed to ``build_norm_layer``.
        dcn (dict, optional): Deformable conv config; must be None (not
            implemented for this block).
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN'),
                 dcn=None):
        super(BasicBlock, self).__init__()
        assert dcn is None, "Not implemented yet."

        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)

        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = conv3x3(planes, planes)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        # Previously rejected via `assert not with_cp`; now supported with the
        # same checkpointing scheme as Bottleneck.
        self.with_cp = with_cp

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        # Mirror Bottleneck: only checkpoint when gradients flow through x.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        # Final activation stays outside the checkpointed segment.
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block used by ResNet-50/101/152.

    The output has ``planes * expansion`` channels.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN'),
                 dcn=None):
        """Bottleneck block for ResNet.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.

        Args:
            inplanes (int): Number of input channels.
            planes (int): Channels of the inner 1x1/3x3 convs; the block
                outputs ``planes * 4`` channels.
            stride (int): Stride of the block (placed per ``style``).
            dilation (int): Dilation (and padding) of the 3x3 conv.
            downsample (nn.Module, optional): Applied to the input to form
                the identity branch; None means the input is added as-is.
            style (str): "pytorch" or "caffe".
            with_cp (bool): Use gradient checkpointing for the residual
                branch.
            normalize (dict): Norm layer config for ``build_norm_layer``.
            dcn (dict, optional): Deformable 3x3 conv config. Keys read here:
                ``modulated`` (bool), ``deformable_groups`` (int, default 1)
                and ``fallback_on_stride`` (bool; if True a plain conv is
                built instead of a deformable one).
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        self.inplanes = inplanes
        self.planes = planes
        self.normalize = normalize
        self.dcn = dcn
        self.with_dcn = dcn is not None
        if style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            normalize, planes * self.expansion, postfix=3)

        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)

        fallback_on_stride = False
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        # True only when conv2 really is a deformable conv fed by
        # conv2_offset. With fallback_on_stride a plain conv is built and no
        # offset branch exists, so forward must not take the DCN path.
        self.with_dcn_conv2 = self.with_dcn and not fallback_on_stride
        self.deformable_groups = 1
        if not self.with_dcn_conv2:
            self.conv2 = nn.Conv2d(
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            self.deformable_groups = dcn.get('deformable_groups', 1)
            # Per deformable group, a 3x3 kernel needs 2 * 9 = 18 offset
            # channels, plus 9 mask channels in the modulated variant.
            if not self.with_modulated_dcn:
                conv_op = DeformConv
                offset_channels = 18
            else:
                conv_op = ModulatedDeformConv
                offset_channels = 27
            self.conv2_offset = nn.Conv2d(
                planes,
                self.deformable_groups * offset_channels,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation)
            self.conv2 = conv_op(
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                deformable_groups=self.deformable_groups,
                bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp
        self.normalize = normalize

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        return getattr(self, self.norm3_name)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if not self.with_dcn_conv2:
                out = self.conv2(out)
            elif self.with_modulated_dcn:
                offset_mask = self.conv2_offset(out)
                # conv2_offset emits 27 * deformable_groups channels: the
                # first 18 * g are offsets, the remaining 9 * g the mask.
                num_offset = 18 * self.deformable_groups
                offset = offset_mask[:, :num_offset, :, :]
                mask = offset_mask[:, num_offset:, :, :].sigmoid()
                out = self.conv2(out, offset, mask)
            else:
                offset = self.conv2_offset(out)
                out = self.conv2(out, offset)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        # Checkpoint only when gradients flow through x.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        # Final activation stays outside the checkpointed segment.
        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   normalize=dict(type='BN'),
                   dcn=None):
    """Stack residual blocks into one ResNet stage.

    The first block receives ``stride`` and, when the identity branch would
    not match the block output (stride != 1 or a channel change), a 1x1
    conv + norm ``downsample`` projection. Every following block has
    stride 1 and ``planes * block.expansion`` input channels. At least one
    block is always created, even if ``blocks`` is smaller than 1.

    Returns:
        nn.Sequential: The stacked blocks.
    """
    out_channels = planes * block.expansion
    shared_kwargs = dict(
        style=style, with_cp=with_cp, normalize=normalize, dcn=dcn)

    projection = None
    if stride != 1 or inplanes != out_channels:
        projection = nn.Sequential(
            nn.Conv2d(
                inplanes,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(normalize, out_channels)[1],
        )

    head = block(inplanes, planes, stride, dilation, projection,
                 **shared_kwargs)
    tail = [
        block(out_channels, planes, 1, dilation, **shared_kwargs)
        for _ in range(1, blocks)
    ]
    return nn.Sequential(head, *tail)


@BACKBONES.register_module
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed and the
            frozen modules kept in eval mode). -1 means not freezing any
            parameters; 0 freezes only the stem (conv1 + norm1).
        normalize (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        dcn (dict, optional): Deformable conv config forwarded to the blocks
            of every stage whose ``stage_with_dcn`` entry is True.
        stage_with_dcn (Sequence[bool]): Per-stage switch for ``dcn``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
    """

    # depth -> (block type, number of blocks per stage)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 normalize=dict(type='BN', frozen=False),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.normalize = normalize
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        # Names of the registered stage modules ('layer1', 'layer2', ...).
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            stage_dcn = self.dcn if self.stage_with_dcn[i] else None
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                normalize=normalize,
                dcn=stage_dcn)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count of the last built stage's output.
        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self):
        """Build the stem: 7x7/2 conv, norm, ReLU and 3x3/2 max pool."""
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1_name, norm1 = build_norm_layer(
            self.normalize, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Stop gradients and stat updates in the stem and frozen stages.

        Besides ``requires_grad = False`` the frozen modules are put in eval
        mode so their norm layers' running statistics do not keep changing
        while the rest of the network trains (relevant when
        ``norm_eval=False``). ``train()`` re-applies this because
        ``nn.Module.train`` would otherwise switch them back.
        """
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load weights from ``pretrained`` (a checkpoint path) or init them.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                # Zero offsets make a deformable conv start out as a
                # regular conv.
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Return a tuple with the outputs of the stages in ``out_indices``."""
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        """Set train/eval mode, keeping frozen parts and (optionally) BN in eval.

        Returns:
            ResNet: ``self``, matching the ``nn.Module.train`` contract.
        """
        super(ResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
        return self