import logging

import torch.nn as nn
import torch.utils.checkpoint as cp
from torch.nn.modules.batchnorm import _BatchNorm

from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint

from mmdet.ops import DeformConv, ModulatedDeformConv, ContextBlock
from mmdet.models.plugins import GeneralizedAttention
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 gcb=None,
                 gen_attention=None):
        super(BasicBlock, self).__init__()
        assert dcn is None, "Not implemented yet."
        assert gen_attention is None, "Not implemented yet."
        assert gcb is None, "Not implemented yet."

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        assert not with_cp

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.norm2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 gcb=None,
                 gen_attention=None):
        """Bottleneck block for ResNet.
        If style is "pytorch", the stride-two layer is the 3x3 conv layer; if
        it is "caffe", the stride-two layer is the first 1x1 conv layer.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        assert gcb is None or isinstance(gcb, dict)
        assert gen_attention is None or isinstance(gen_attention, dict)

        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
        self.gcb = gcb
        self.with_gcb = gcb is not None
        self.gen_attention = gen_attention
        self.with_gen_attention = gen_attention is not None

        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            deformable_groups = dcn.get('deformable_groups', 1)
            if not self.with_modulated_dcn:
                conv_op = DeformConv
                offset_channels = 18
            else:
                conv_op = ModulatedDeformConv
                offset_channels = 27
            self.conv2_offset = nn.Conv2d(
                planes,
                deformable_groups * offset_channels,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation)
            self.conv2 = conv_op(
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                deformable_groups=deformable_groups,
                bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        if self.with_gcb:
            gcb_inplanes = planes * self.expansion
            self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb)

        # gen_attention
        if self.with_gen_attention:
            self.gen_attention_block = GeneralizedAttention(
                planes, **gen_attention)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        return getattr(self, self.norm3_name)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if not self.with_dcn:
                out = self.conv2(out)
            elif self.with_modulated_dcn:
                offset_mask = self.conv2_offset(out)
                # first 18 channels are the (x, y) offsets, last 9 the
                # modulation mask (this slicing assumes deformable_groups=1)
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, -9:, :, :].sigmoid()
                out = self.conv2(out, offset, mask)
            else:
                offset = self.conv2_offset(out)
                out = self.conv2(out, offset)
            out = self.norm2(out)
            out = self.relu(out)

            if self.with_gen_attention:
                out = self.gen_attention_block(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_gcb:
                out = self.context_block(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
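

# Illustrative sketch (not part of the original module): the `style` argument
# controls which conv in a Bottleneck carries the stride, as described in its
# docstring. The channel sizes below are arbitrary demo values, roughly
# matching the first unit of a stride-2 stage.
def _bottleneck_style_demo():
    def _downsample():
        return nn.Sequential(
            build_conv_layer(
                None, 256, 512, kernel_size=1, stride=2, bias=False),
            build_norm_layer(dict(type='BN'), 512)[1])

    pytorch_blk = Bottleneck(
        256, 128, stride=2, downsample=_downsample(), style='pytorch')
    caffe_blk = Bottleneck(
        256, 128, stride=2, downsample=_downsample(), style='caffe')
    # 'pytorch' puts stride 2 on the 3x3 conv, 'caffe' on the first 1x1 conv.
    assert (pytorch_blk.conv1_stride, pytorch_blk.conv2_stride) == (1, 2)
    assert (caffe_blk.conv1_stride, caffe_blk.conv2_stride) == (2, 1)
    return pytorch_blk, caffe_blk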


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   conv_cfg=None,
                   norm_cfg=dict(type='BN'),
                   dcn=None,
                   gcb=None,
                   gen_attention=None,
                   gen_attention_blocks=[]):
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, planes * block.expansion)[1],
        )

    layers = []
    layers.append(
        block(
            inplanes,
            planes,
            stride,
            dilation,
            downsample,
            style=style,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            dcn=dcn,
            gcb=gcb,
            gen_attention=gen_attention if
            (0 in gen_attention_blocks) else None))
    inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(
            block(
                inplanes,
                planes,
                1,
                dilation,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                gcb=gcb,
                gen_attention=gen_attention if
                (i in gen_attention_blocks) else None))

    return nn.Sequential(*layers)
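

# Illustrative sketch (not part of the original module): how `make_res_layer`
# is typically called to assemble one stage. The numbers below mirror `layer2`
# of ResNet-50 (4 Bottleneck blocks, 256 -> 512 channels, stride 2) and are
# assumptions for demonstration only.
def _example_res_stage():
    return make_res_layer(
        Bottleneck,
        inplanes=256,
        planes=128,
        blocks=4,
        stride=2,
        norm_cfg=dict(type='BN'))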


@BACKBONES.register_module
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        norm_cfg (dict): Dictionary to construct and configure the norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): Whether to zero-init the last norm layer in
            residual blocks, so that they initially behave as identity.
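
    Example:
        A minimal usage sketch; it assumes ``ResNet`` is exposed via
        ``mmdet.models`` as in mmdetection.

        >>> import torch
        >>> from mmdet.models import ResNet
        >>> self = ResNet(depth=18)
        >>> _ = self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 8, 8)
        (1, 128, 4, 4)
        (1, 256, 2, 2)
        (1, 512, 1, 1)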
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 gcb=None,
                 stage_with_gcb=(False, False, False, False),
                 gen_attention=None,
                 stage_with_gen_attention=((), (), (), ()),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.gen_attention = gen_attention
        self.gcb = gcb
        self.stage_with_gcb = stage_with_gcb
        if gcb is not None:
            assert len(stage_with_gcb) == num_stages
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            gcb = self.gcb if self.stage_with_gcb[i] else None
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                gcb=gcb,
                gen_attention=gen_attention,
                gen_attention_blocks=stage_with_gen_attention[i])
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self):
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            3,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        super(ResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect only on BatchNorm layers
                if isinstance(m, _BatchNorm):
                    m.eval()
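

# Illustrative usage sketch (not part of the original module): building the
# backbone roughly the way an mmdetection config would, with plain
# (non-modulated) DCN in the last three stages and the stem plus first stage
# frozen. The exact values are assumptions for demonstration; DCN additionally
# requires the compiled mmdet ops.
def _example_build_backbone():
    model = ResNet(
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        dcn=dict(
            modulated=False, deformable_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, True, True, True),
        style='pytorch')
    model.init_weights(pretrained=None)
    return model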