import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.models.plugins import GeneralizedAttention
from mmdet.ops import ContextBlock, DeformConv, ModulatedDeformConv
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 gcb=None,
                 gen_attention=None):
        super(BasicBlock, self).__init__()
        assert dcn is None, "Not implemented yet."
        assert gen_attention is None, "Not implemented yet."
        assert gcb is None, "Not implemented yet."

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        assert not with_cp

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.norm2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 gcb=None,
                 gen_attention=None):
        """Bottleneck block for ResNet.
101
102
        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        assert gcb is None or isinstance(gcb, dict)
        assert gen_attention is None or isinstance(gen_attention, dict)

        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
        self.gcb = gcb
        self.with_gcb = gcb is not None
        self.gen_attention = gen_attention
        self.with_gen_attention = gen_attention is not None

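        # 'pytorch' style places the stride-2 conv on the 3x3 conv (conv2);
        # 'caffe' style places it on the first 1x1 conv (conv1).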
        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
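        # conv2 is a regular 3x3 conv unless DCN is enabled and not configured
        # to fall back to a regular conv via 'fallback_on_stride'.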
        fallback_on_stride = False
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            self.deformable_groups = dcn.get('deformable_groups', 1)
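            # A 3x3 deformable kernel samples 9 locations: DeformConv predicts
            # an (x, y) offset per location (2 * 9 = 18 channels per group);
            # ModulatedDeformConv also predicts a mask (3 * 9 = 27 per group).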
            if not self.with_modulated_dcn:
                conv_op = DeformConv
                offset_channels = 18
            else:
                conv_op = ModulatedDeformConv
                offset_channels = 27
            self.conv2_offset = nn.Conv2d(
                planes,
                self.deformable_groups * offset_channels,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation)
            self.conv2 = conv_op(
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                deformable_groups=self.deformable_groups,
                bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        if self.with_gcb:
            gcb_inplanes = planes * self.expansion
            self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb)

        # gen_attention
        if self.with_gen_attention:
            self.gen_attention_block = GeneralizedAttention(
                planes, **gen_attention)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        return getattr(self, self.norm3_name)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if not self.with_dcn:
                out = self.conv2(out)
            elif self.with_modulated_dcn:
                offset_mask = self.conv2_offset(out)
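                # first 18 * groups channels are (x, y) offsets, the remaining
                # 9 * groups channels are the sigmoid-gated modulation mask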
                offset = offset_mask[:, :18 * self.deformable_groups, :, :]
                mask = offset_mask[:, -9 * self.deformable_groups:, :, :]
                mask = mask.sigmoid()
                out = self.conv2(out, offset, mask)
            else:
                offset = self.conv2_offset(out)
                out = self.conv2(out, offset)
            out = self.norm2(out)
            out = self.relu(out)

            if self.with_gen_attention:
                out = self.gen_attention_block(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_gcb:
                out = self.context_block(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

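        # with checkpointing, activations inside _inner_forward are recomputed
        # during the backward pass, trading extra compute for less memory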
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   conv_cfg=None,
                   norm_cfg=dict(type='BN'),
                   dcn=None,
                   gcb=None,
                   gen_attention=None,
                   gen_attention_blocks=[]):
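    # project the shortcut with a 1x1 conv + norm whenever the first block
    # changes the spatial resolution (stride != 1) or the channel count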
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, planes * block.expansion)[1],
        )

    layers = []
    layers.append(
        block(
            inplanes=inplanes,
            planes=planes,
            stride=stride,
            dilation=dilation,
            downsample=downsample,
            style=style,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            dcn=dcn,
            gcb=gcb,
            gen_attention=gen_attention if
            (0 in gen_attention_blocks) else None))
    inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(
            block(
                inplanes=inplanes,
                planes=planes,
                stride=1,
                dilation=dilation,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                gcb=gcb,
                gen_attention=gen_attention if
                (i in gen_attention_blocks) else None))

    return nn.Sequential(*layers)


@BACKBONES.register_module
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Normally 3.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): Whether to zero-init the last norm layer
            in residual blocks, so that they initially behave as identity.

    Example:
        >>> from mmdet.models import ResNet
        >>> import torch
        >>> self = ResNet(depth=18)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 8, 8)
        (1, 128, 4, 4)
        (1, 256, 2, 2)
        (1, 512, 1, 1)
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 gcb=None,
                 stage_with_gcb=(False, False, False, False),
                 gen_attention=None,
                 stage_with_gen_attention=((), (), (), ()),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.gen_attention = gen_attention
        self.gcb = gcb
        self.stage_with_gcb = stage_with_gcb
        if gcb is not None:
            assert len(stage_with_gcb) == num_stages
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer(in_channels)

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            gcb = self.gcb if self.stage_with_gcb[i] else None
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                gcb=gcb,
                gen_attention=gen_attention,
                gen_attention_blocks=stage_with_gen_attention[i])
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self, in_channels):
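        # stem: 7x7 stride-2 conv + 3x3 stride-2 max pool, i.e. an overall
        # downsampling factor of 4 before the residual stages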
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
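        # put the stem and the first `frozen_stages` residual stages into eval
        # mode and stop their gradients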
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            from mmdet.apis import get_root_logger
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

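            # zero-init the last norm layer of each residual block so that the
            # residual branch initially contributes (close to) nothing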
            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        super(ResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() only has an effect on BatchNorm layers
                if isinstance(m, _BatchNorm):
                    m.eval()
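

# Usage sketch: one plausible way to build this backbone with modulated DCN
# enabled in the last three stages. The dcn keys ('modulated',
# 'deformable_groups', 'fallback_on_stride') are the ones read in
# Bottleneck.__init__ above; the values below are only illustrative.
#
#     model = ResNet(
#         depth=50,
#         num_stages=4,
#         out_indices=(0, 1, 2, 3),
#         frozen_stages=1,
#         dcn=dict(
#             modulated=True, deformable_groups=1, fallback_on_stride=False),
#         stage_with_dcn=(False, True, True, True))
#     model.init_weights(pretrained=None)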