resnet.py 17 KB
Newer Older
1
2
import logging

Kai Chen's avatar
Kai Chen committed
3
4
import torch.nn as nn
import torch.utils.checkpoint as cp
Kai Chen's avatar
Kai Chen committed
5
from mmcv.cnn import constant_init, kaiming_init
Kai Chen's avatar
Kai Chen committed
6
from mmcv.runner import load_checkpoint
7
from torch.nn.modules.batchnorm import _BatchNorm
Kai Chen's avatar
Kai Chen committed
8

9
from mmdet.models.plugins import GeneralizedAttention
10
from mmdet.ops import ContextBlock, DeformConv, ModulatedDeformConv
Kai Chen's avatar
Kai Chen committed
11
from ..registry import BACKBONES
12
from ..utils import build_conv_layer, build_norm_layer
Kai Chen's avatar
Kai Chen committed
13
14
15
16
17
18
19
20
21
22
23


class BasicBlock(nn.Module):
    """Basic residual block (two 3x3 convs) used by ResNet-18/34.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels of both 3x3 convs. Since
            ``expansion`` is 1 this is also the block's output width.
        stride (int): Stride of the first 3x3 conv.
        dilation (int): Dilation (and padding) of the first 3x3 conv.
        downsample (nn.Module | None): Applied to the input to build the
            identity branch when it is not None.
        style (str): Kept for interface parity with ``Bottleneck``; it does
            not change this block's layout.
        with_cp (bool): Run the residual branch under
            ``torch.utils.checkpoint`` (when the input requires grad) to save
            memory at the cost of recomputation.
        conv_cfg (dict | None): Config passed to ``build_conv_layer``.
        norm_cfg (dict): Config passed to ``build_norm_layer``.
        dcn, gcb, gen_attention: Not supported by this block; must be None.
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 gcb=None,
                 gen_attention=None):
        super(BasicBlock, self).__init__()
        assert dcn is None, "Not implemented yet."
        assert gen_attention is None, "Not implemented yet."
        assert gcb is None, "Not implemented yet."

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.style = style
        # Previously asserted False; now supported the same way as in
        # Bottleneck so ResNet-18/34 can be built with with_cp=True.
        self.with_cp = with_cp

    @property
    def norm1(self):
        """The norm layer following conv1 (registered under norm1_name)."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """The norm layer following conv2 (registered under norm2_name)."""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Apply the residual branch, add the identity and apply ReLU."""

        def _inner_forward(x):
            # Everything before the final ReLU; this is what gets
            # checkpointed when with_cp is enabled.
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        # Checkpointing only makes sense when gradients are needed.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 gcb=None,
                 gen_attention=None):
        """Bottleneck block for ResNet.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.

        Args:
            inplanes (int): Number of input channels.
            planes (int): Width of the 1x1 and 3x3 convs; the block outputs
                ``planes * expansion`` channels.
            stride (int): Block stride, put on conv1 or conv2 by ``style``.
            dilation (int): Dilation (and padding) of the 3x3 conv.
            downsample (nn.Module | None): Applied to the input to build the
                identity branch when it is not None.
            style (str): 'pytorch' or 'caffe' (see above).
            with_cp (bool): Run the residual branch under
                ``torch.utils.checkpoint`` when the input requires grad.
            conv_cfg (dict | None): Config for ``build_conv_layer``; must be
                None when the 3x3 conv is deformable.
            norm_cfg (dict): Config for ``build_norm_layer``.
            dcn (dict | None): Deformable-conv options. Keys read here:
                ``fallback_on_stride`` (build a regular 3x3 conv instead),
                ``modulated`` (use ModulatedDeformConv) and
                ``deformable_groups``.
            gcb (dict | None): Kwargs for a ContextBlock applied after norm3.
            gen_attention (dict | None): Kwargs for a GeneralizedAttention
                module applied after the 3x3 conv stage.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        assert gcb is None or isinstance(gcb, dict)
        assert gen_attention is None or isinstance(gen_attention, dict)

        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
        self.gcb = gcb
        self.with_gcb = gcb is not None
        self.gen_attention = gen_attention
        self.with_gen_attention = gen_attention is not None

        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)

        fallback_on_stride = False
        self.with_modulated_dcn = False
        self.deformable_groups = 1
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        # Remember that conv2 is a plain conv despite a DCN config, so that
        # forward() does not try to use the (non-existent) conv2_offset.
        self.fallback_on_stride = fallback_on_stride
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            self.deformable_groups = dcn.get('deformable_groups', 1)
            if not self.with_modulated_dcn:
                conv_op = DeformConv
                # 2 offsets per position of the 3x3 kernel.
                offset_channels = 18
            else:
                conv_op = ModulatedDeformConv
                # 18 offsets followed by 9 mask values (split in forward).
                offset_channels = 27
            self.conv2_offset = nn.Conv2d(
                planes,
                self.deformable_groups * offset_channels,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation)
            self.conv2 = conv_op(
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                deformable_groups=self.deformable_groups,
                bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        if self.with_gcb:
            gcb_inplanes = planes * self.expansion
            self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb)

        # gen_attention
        if self.with_gen_attention:
            self.gen_attention_block = GeneralizedAttention(
                planes, **gen_attention)

    @property
    def norm1(self):
        """The norm layer following conv1 (registered under norm1_name)."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """The norm layer following conv2 (registered under norm2_name)."""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """The norm layer following conv3 (registered under norm3_name)."""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Apply the 1x1-3x3-1x1 branch, add the identity and apply ReLU."""

        def _inner_forward(x):
            # Everything before the final ReLU; checkpointed under with_cp.
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if not self.with_dcn or self.fallback_on_stride:
                out = self.conv2(out)
            elif self.with_modulated_dcn:
                offset_mask = self.conv2_offset(out)
                # conv2_offset emits 27 * G channels: the first 18 * G are
                # offsets, the remaining 9 * G are the modulation mask.
                num_offset = 18 * self.deformable_groups
                offset = offset_mask[:, :num_offset, :, :]
                mask = offset_mask[:, num_offset:, :, :].sigmoid()
                out = self.conv2(out, offset, mask)
            else:
                offset = self.conv2_offset(out)
                out = self.conv2(out, offset)
            out = self.norm2(out)
            out = self.relu(out)

            if self.with_gen_attention:
                out = self.gen_attention_block(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_gcb:
                out = self.context_block(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        # Checkpointing only makes sense when gradients are needed.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   conv_cfg=None,
                   norm_cfg=dict(type='BN'),
                   dcn=None,
                   gcb=None,
                   gen_attention=None,
                   gen_attention_blocks=()):
    """Stack residual blocks into one ResNet stage.

    Only the first block gets ``stride`` and, when the stride is not 1 or
    the channel count changes, a 1x1 conv + norm ``downsample`` projection.
    The remaining ``blocks - 1`` blocks use stride 1 and take
    ``planes * block.expansion`` input channels.

    Args:
        block (type): Block class (e.g. BasicBlock or Bottleneck) exposing
            an ``expansion`` attribute.
        inplanes (int): Input channels of the stage.
        planes (int): Base width of each block.
        blocks (int): Number of blocks in the stage.
        stride (int): Stride of the first block.
        dilation (int): Dilation passed to every block.
        style, with_cp, conv_cfg, norm_cfg, dcn, gcb: Forwarded unchanged to
            every block.
        gen_attention (dict | None): Passed only to blocks whose index is
            in ``gen_attention_blocks``; others receive None.
        gen_attention_blocks (Sequence[int]): Block indices (within this
            stage) that get ``gen_attention``.

    Returns:
        nn.Sequential: The stacked blocks.
    """
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, planes * block.expansion)[1],
        )

    # Keyword arguments shared by every block of the stage.
    common_kwargs = dict(
        dilation=dilation,
        style=style,
        with_cp=with_cp,
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg,
        dcn=dcn,
        gcb=gcb)

    layers = [
        block(
            inplanes=inplanes,
            planes=planes,
            stride=stride,
            downsample=downsample,
            gen_attention=gen_attention if
            (0 in gen_attention_blocks) else None,
            **common_kwargs)
    ]
    inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(
            block(
                inplanes=inplanes,
                planes=planes,
                stride=1,
                gen_attention=gen_attention if
                (i in gen_attention_blocks) else None,
                **common_kwargs))

    return nn.Sequential(*layers)


Kai Chen's avatar
Kai Chen committed
331
@BACKBONES.register_module
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        conv_cfg (dict | None): Config forwarded to ``build_conv_layer``.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        dcn (dict | None): Deformable-conv config given to the blocks of the
            stages flagged in ``stage_with_dcn`` (Bottleneck only).
        stage_with_dcn (Sequence[bool]): Per-stage switch for ``dcn``.
        gcb (dict | None): ContextBlock config given to the blocks of the
            stages flagged in ``stage_with_gcb`` (Bottleneck only).
        stage_with_gcb (Sequence[bool]): Per-stage switch for ``gcb``.
        gen_attention (dict | None): GeneralizedAttention config.
        stage_with_gen_attention (Sequence[Sequence[int]]): For each stage,
            the indices of the blocks that receive ``gen_attention``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
    """

    # depth -> (block class, number of blocks per stage)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 gcb=None,
                 stage_with_gcb=(False, False, False, False),
                 gen_attention=None,
                 stage_with_gen_attention=((), (), (), ()),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.gen_attention = gen_attention
        self.stage_with_gen_attention = stage_with_gen_attention
        self.gcb = gcb
        self.stage_with_gcb = stage_with_gcb
        if gcb is not None:
            assert len(stage_with_gcb) == num_stages
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        # Names of the registered stage modules ('layer1', ...) in order.
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            gcb = self.gcb if self.stage_with_gcb[i] else None
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                gcb=gcb,
                gen_attention=gen_attention,
                gen_attention_blocks=stage_with_gen_attention[i])
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count of the last built stage's output.
        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        """The stem norm layer (registered under norm1_name)."""
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self):
        """Build the stem: 7x7/2 conv, norm, ReLU and 3x3/2 max-pool."""
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            3,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Freeze the stem and the first ``frozen_stages`` stages.

        Frozen modules are put in eval mode and their parameters get
        ``requires_grad = False``. The stem is frozen whenever
        ``frozen_stages >= 0``.
        """
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize weights from a checkpoint or from scratch.

        Args:
            pretrained (str | None): If a str, it is passed to
                ``load_checkpoint`` (non-strict). If None, convs get kaiming
                init, BN/GN get constant 1, DCN offset convs get 0 and, with
                ``zero_init_residual``, each block's last norm gets 0.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        # Zero offsets make the deformable conv start out
                        # sampling the regular grid.
                        constant_init(m.conv2_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stem and stages.

        Returns:
            tuple: Outputs of the stages whose index is in ``out_indices``.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        """Set train/eval mode while keeping frozen parts frozen.

        With ``norm_eval`` and ``mode=True``, every ``_BatchNorm`` is put
        back into eval mode so its running statistics are not updated.

        Returns:
            ResNet: ``self``, as required by the ``nn.Module.train``
            contract (``nn.Module.eval`` returns ``self.train(False)``).
        """
        super(ResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self