resnet.py 16.5 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
import torch.nn as nn
import torch.utils.checkpoint as cp
Kai Chen's avatar
Kai Chen committed
3
from mmcv.cnn import constant_init, kaiming_init
Kai Chen's avatar
Kai Chen committed
4
from mmcv.runner import load_checkpoint
5
from torch.nn.modules.batchnorm import _BatchNorm
Kai Chen's avatar
Kai Chen committed
6

7
from mmdet.models.plugins import GeneralizedAttention
8
from mmdet.ops import ContextBlock
Kai Chen's avatar
Kai Chen committed
9
from mmdet.utils import get_root_logger
Kai Chen's avatar
Kai Chen committed
10
from ..registry import BACKBONES
11
from ..utils import build_conv_layer, build_norm_layer
Kai Chen's avatar
Kai Chen committed
12
13
14
15
16
17
18
19
20
21
22


class BasicBlock(nn.Module):
    """Basic residual block (two 3x3 convs) used by shallow ResNets (18/34)."""

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 gcb=None,
                 gen_attention=None):
        super(BasicBlock, self).__init__()
        # These plugin arguments exist only for signature parity with
        # Bottleneck; the basic block supports none of them.
        assert dcn is None, "Not implemented yet."
        assert gen_attention is None, "Not implemented yet."
        assert gcb is None, "Not implemented yet."
        # Gradient checkpointing is likewise unsupported here.
        assert not with_cp

        self.stride = stride
        self.dilation = dilation
        self.downsample = downsample

        # Norm layers get dynamic names (e.g. "bn1"/"gn1") from the config,
        # hence the name/module pairs registered via add_module below.
        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)

    @property
    def norm1(self):
        """The first norm layer (looked up by its dynamically chosen name)."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """The second norm layer (looked up by its dynamically chosen name)."""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Apply conv-norm-relu, conv-norm, add the shortcut, then ReLU."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) used by deep ResNets."""

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 gcb=None,
                 gen_attention=None):
        """Bottleneck block for ResNet.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.

        Args:
            inplanes (int): Number of input channels.
            planes (int): Base channel count; the block outputs
                ``planes * expansion`` channels.
            stride (int): Stride of the stride-carrying conv (which conv
                carries it depends on ``style``).
            dilation (int): Dilation of the 3x3 conv.
            downsample (nn.Module | None): Projection for the shortcut when
                shape changes.
            with_cp (bool): Use torch.utils.checkpoint in forward to trade
                compute for memory.
            conv_cfg (dict | None): Config for building conv layers.
            norm_cfg (dict): Config for building norm layers.
            dcn (dict | None): Deformable-conv config for the 3x3 conv.
            gcb (dict | None): ContextBlock (GCNet) config.
            gen_attention (dict | None): GeneralizedAttention config.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        assert gcb is None or isinstance(gcb, dict)
        assert gen_attention is None or isinstance(gen_attention, dict)

        # FIX: work on a copy of the dcn config. 'fallback_on_stride' is
        # popped below, and the incoming dict is shared by every block of a
        # stage (make_res_layer passes the same object), so popping in place
        # would hide the flag from all blocks built after the first one.
        if dcn is not None:
            dcn = dcn.copy()

        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
        self.gcb = gcb
        self.with_gcb = gcb is not None
        self.gen_attention = gen_attention
        self.with_gen_attention = gen_attention is not None

        # "pytorch" puts the stride on the 3x3 conv, "caffe" on the 1x1.
        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        if self.with_dcn:
            # Popping from our private copy strips the key so the remaining
            # dict is a valid conv cfg for the DCN op below.
            fallback_on_stride = dcn.pop('fallback_on_stride', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            # FIX: the original message read "conv_cfg cannot be None for
            # DCN", the opposite of the asserted condition.
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            self.conv2 = build_conv_layer(
                dcn,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)

        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        # Optional global-context block, applied after conv3/norm3.
        if self.with_gcb:
            gcb_inplanes = planes * self.expansion
            self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb)

        # Optional generalized attention, applied after conv2/norm2/relu.
        if self.with_gen_attention:
            self.gen_attention_block = GeneralizedAttention(
                planes, **gen_attention)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Residual forward; optionally checkpointed to save memory."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            if self.with_gen_attention:
                out = self.gen_attention_block(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_gcb:
                out = self.context_block(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        # Checkpointing only helps (and only works) when grads are needed.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   conv_cfg=None,
                   norm_cfg=dict(type='BN'),
                   dcn=None,
                   gcb=None,
                   gen_attention=None,
                   # FIX: was a mutable default ([]); a tuple is a drop-in
                   # replacement since the value is only membership-tested.
                   gen_attention_blocks=()):
    """Stack ``blocks`` residual blocks into one ResNet stage.

    The first block carries ``stride`` and, when the channel count changes,
    a conv+norm ``downsample`` projection on the shortcut; the remaining
    blocks run at stride 1.

    Args:
        block (nn.Module): Block class, ``BasicBlock`` or ``Bottleneck``.
        inplanes (int): Input channels of the stage.
        planes (int): Base channels; the stage outputs
            ``planes * block.expansion``.
        blocks (int): Number of blocks in the stage.
        gen_attention_blocks (Sequence[int]): Indices of blocks that get a
            GeneralizedAttention module.

    Returns:
        nn.Sequential: The assembled stage.
    """
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, planes * block.expansion)[1],
        )

    layers = []
    layers.append(
        block(
            inplanes=inplanes,
            planes=planes,
            stride=stride,
            dilation=dilation,
            downsample=downsample,
            style=style,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            dcn=dcn,
            gcb=gcb,
            gen_attention=gen_attention if
            (0 in gen_attention_blocks) else None))
    inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(
            block(
                inplanes=inplanes,
                planes=planes,
                stride=1,
                dilation=dilation,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                gcb=gcb,
                gen_attention=gen_attention if
                (i in gen_attention_blocks) else None))

    return nn.Sequential(*layers)


Kai Chen's avatar
Kai Chen committed
306
@BACKBONES.register_module
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Normally 3.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        conv_cfg (dict | None): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        dcn (dict | None): Deformable-conv config, applied per stage
            according to ``stage_with_dcn``.
        gcb (dict | None): ContextBlock config, applied per stage according
            to ``stage_with_gcb``.
        gen_attention (dict | None): GeneralizedAttention config; the blocks
            it applies to are listed per stage in ``stage_with_gen_attention``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.

    Example:
        >>> from mmdet.models import ResNet
        >>> import torch
        >>> self = ResNet(depth=18)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 8, 8)
        (1, 128, 4, 4)
        (1, 256, 2, 2)
        (1, 512, 1, 1)
    """

    # depth -> (block class, number of blocks per stage)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 gcb=None,
                 stage_with_gcb=(False, False, False, False),
                 gen_attention=None,
                 stage_with_gen_attention=((), (), (), ()),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.gen_attention = gen_attention
        self.gcb = gcb
        self.stage_with_gcb = stage_with_gcb
        if gcb is not None:
            assert len(stage_with_gcb) == num_stages
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        # Only the first num_stages stages of the canonical architecture.
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer(in_channels)

        # Build the residual stages; each is registered as "layer{i+1}" and
        # its name recorded so forward() can look it up.
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            # Per-stage enabling of the optional plugins.
            dcn = self.dcn if self.stage_with_dcn[i] else None
            gcb = self.gcb if self.stage_with_gcb[i] else None
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                gcb=gcb,
                gen_attention=gen_attention,
                gen_attention_blocks=stage_with_gen_attention[i])
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count of the last stage's output.
        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        # The stem norm layer, registered under a config-dependent name.
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self, in_channels):
        """Build the stem: 7x7 stride-2 conv, norm, ReLU, 3x3 max-pool."""
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Freeze the stem and stages 1..frozen_stages (grads and BN stats)."""
        if self.frozen_stages >= 0:
            # Stem: eval() freezes norm running stats; the loop stops grads.
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Init from a checkpoint path, or from scratch when None.

        Raises:
            TypeError: if ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # Kaiming for convs, constant 1 for norm layers.
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            # DCN offset convs start at zero so DCN begins as a plain conv.
            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            # Zero the last norm of each block so it starts as identity.
            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stem and all stages; return outputs at ``out_indices``."""
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval while keeping frozen stages and norms frozen."""
        super(ResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()