import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.models.plugins import GeneralizedAttention
from mmdet.ops import ContextBlock
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer


class BasicBlock(nn.Module):
    """Two-3x3-conv residual block used by the shallow ResNets (18/34)."""

    # Output channels of the block are ``planes * expansion``.
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 gcb=None,
                 gen_attention=None):
        super(BasicBlock, self).__init__()
        # The dcn / gcb / gen_attention plugins are Bottleneck-only for now.
        assert dcn is None, "Not implemented yet."
        assert gen_attention is None, "Not implemented yet."
        assert gcb is None, "Not implemented yet."

        # build_norm_layer returns (name, module); the module is registered
        # under the cfg-derived name and retrieved via the properties below.
        self.norm1_name, bn_first = build_norm_layer(
            norm_cfg, planes, postfix=1)
        self.norm2_name, bn_second = build_norm_layer(
            norm_cfg, planes, postfix=2)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, bn_first)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, bn_second)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        # Gradient checkpointing is not supported by the basic block.
        assert not with_cp

    @property
    def norm1(self):
        """Normalization layer following conv1."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Normalization layer following conv2."""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """conv-norm-relu, conv-norm, residual add, final relu."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Bottleneck block for ResNet (1x1 -> 3x3 -> 1x1 convs).

    If style is "pytorch", the stride-two layer is the 3x3 conv layer,
    if it is "caffe", the stride-two layer is the first 1x1 conv layer.
    The 3x3 conv can optionally be a deformable conv (``dcn``), and the
    block can carry a ContextBlock (``gcb``) and/or a GeneralizedAttention
    plugin (``gen_attention``).
    """

    # Output channels of the block are ``planes * expansion``.
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 gcb=None,
                 gen_attention=None):
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        assert gcb is None or isinstance(gcb, dict)
        assert gen_attention is None or isinstance(gen_attention, dict)

        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
        self.gcb = gcb
        self.with_gcb = gcb is not None
        self.gen_attention = gen_attention
        self.with_gen_attention = gen_attention is not None

        # "pytorch" puts the stride on the 3x3 conv, "caffe" on the first 1x1.
        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        if self.with_dcn:
            # BUGFIX: copy before pop() so the shared dcn config dict (the
            # same object is handed to every block of a stage) is not
            # mutated. Popping the shared dict stripped the key after the
            # first block, silently forcing fallback_on_stride=False for
            # all subsequent blocks.
            dcn = dict(dcn)
            self.dcn = dcn
            fallback_on_stride = dcn.pop('fallback_on_stride', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            # BUGFIX: the message used to say "cannot be None", which
            # contradicts the assertion -- conv_cfg must BE None when the
            # dcn config builds conv2.
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            self.conv2 = build_conv_layer(
                dcn,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)

        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        # The context block sees the block's expanded output channels.
        if self.with_gcb:
            gcb_inplanes = planes * self.expansion
            self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb)

        # gen_attention is applied on the mid (3x3 conv) channels.
        if self.with_gen_attention:
            self.gen_attention_block = GeneralizedAttention(
                planes, **gen_attention)

    @property
    def norm1(self):
        """Normalization layer following conv1."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Normalization layer following conv2."""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """Normalization layer following conv3."""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Residual forward pass, optionally gradient-checkpointed."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            if self.with_gen_attention:
                out = self.gen_attention_block(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_gcb:
                out = self.context_block(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        # Checkpointing only pays off (and only works) when grads are needed.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   conv_cfg=None,
                   norm_cfg=dict(type='BN'),
                   dcn=None,
                   gcb=None,
                   gen_attention=None,
                   # BUGFIX: immutable () instead of the mutable default [].
                   gen_attention_blocks=()):
    """Build one ResNet stage as an ``nn.Sequential`` of residual blocks.

    Args:
        block (type): Block class, ``BasicBlock`` or ``Bottleneck``.
        inplanes (int): Input channels of the stage.
        planes (int): Base channels of every block; the stage outputs
            ``planes * block.expansion`` channels.
        blocks (int): Number of blocks stacked in the stage.
        stride (int): Stride of the first block; remaining blocks use 1.
        dilation (int): Dilation passed to every block.
        style (str): 'pytorch' or 'caffe', forwarded to the blocks.
        with_cp (bool): Forwarded to the blocks (gradient checkpointing).
        conv_cfg (dict | None): Conv layer config forwarded to the blocks.
        norm_cfg (dict): Norm layer config forwarded to the blocks.
        dcn (dict | None): Deformable-conv config forwarded to the blocks.
        gcb (dict | None): ContextBlock config forwarded to the blocks.
        gen_attention (dict | None): GeneralizedAttention config.
        gen_attention_blocks (Sequence[int]): Indices of the blocks that
            receive the ``gen_attention`` plugin.

    Returns:
        nn.Sequential: The assembled stage.
    """
    # A 1x1 conv + norm projection is needed whenever the first block
    # changes the spatial resolution or the channel count.
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, planes * block.expansion)[1],
        )

    layers = []
    # Only the first block carries the stride and the downsample projection.
    layers.append(
        block(
            inplanes=inplanes,
            planes=planes,
            stride=stride,
            dilation=dilation,
            downsample=downsample,
            style=style,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            dcn=dcn,
            gcb=gcb,
            gen_attention=gen_attention if
            (0 in gen_attention_blocks) else None))
    inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(
            block(
                inplanes=inplanes,
                planes=planes,
                stride=1,
                dilation=dilation,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                gcb=gcb,
                gen_attention=gen_attention if
                (i in gen_attention_blocks) else None))

    return nn.Sequential(*layers)


@BACKBONES.register_module
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Normally 3.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.

    Example:
        >>> from mmdet.models import ResNet
        >>> import torch
        >>> self = ResNet(depth=18)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 8, 8)
        (1, 128, 4, 4)
        (1, 256, 2, 2)
        (1, 512, 1, 1)
    """

    # depth -> (block class, number of blocks per stage)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 gcb=None,
                 stage_with_gcb=(False, False, False, False),
                 gen_attention=None,
                 stage_with_gen_attention=((), (), (), ()),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        # Per-stage plugin switches must match the number of stages.
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.gen_attention = gen_attention
        self.gcb = gcb
        self.stage_with_gcb = stage_with_gcb
        if gcb is not None:
            assert len(stage_with_gcb) == num_stages
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        # Stem output channels; updated per stage in the loop below.
        self.inplanes = 64

        self._make_stem_layer(in_channels)

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            # Plugins are only handed to the stages that enable them.
            dcn = self.dcn if self.stage_with_dcn[i] else None
            gcb = self.gcb if self.stage_with_gcb[i] else None
            # Stage i works on 64 * 2**i base channels.
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                gcb=gcb,
                gen_attention=gen_attention,
                gen_attention_blocks=stage_with_gen_attention[i])
            self.inplanes = planes * self.block.expansion
            # Registered as layer1..layerN so _freeze_stages can look them up.
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count of the deepest stage's output.
        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        """The stem's normalization layer (name set by _make_stem_layer)."""
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self, in_channels):
        """Build the stem: 7x7 stride-2 conv, norm, ReLU, 3x3 max-pool."""
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Freeze the stem and the first ``frozen_stages`` res layers."""
        # frozen_stages >= 0 freezes the stem (conv1 + its norm).
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        # Freeze layer1 .. layer{frozen_stages}.
        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load a checkpoint if ``pretrained`` is a path, else random init.

        Args:
            pretrained (str | None): Checkpoint path, or None for Kaiming
                conv init and constant norm init.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            # Imported lazily to avoid a circular import at module load time
            # -- TODO confirm; mmdet.apis imports models as well.
            from mmdet.apis import get_root_logger
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            # Deformable conv offsets start at zero (plain conv behavior).
            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            # Zero the last norm of each block so residual branches start
            # as identity mappings.
            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stem then each stage, collecting ``out_indices`` outputs.

        Returns:
            tuple[Tensor]: One feature map per index in ``out_indices``.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode, re-freezing stages and (optionally) BN."""
        super(ResNet, self).train(mode)
        # nn.Module.train() flips every submodule, so re-apply the freeze.
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()