resnet.py 12 KB
Newer Older
1
2
import logging

Kai Chen's avatar
Kai Chen committed
3
4
import torch.nn as nn
import torch.utils.checkpoint as cp
Kai Chen's avatar
Kai Chen committed
5
6

from mmcv.cnn import constant_init, kaiming_init
Kai Chen's avatar
Kai Chen committed
7
from mmcv.runner import load_checkpoint
8
from ..utils import build_norm_layer
Kai Chen's avatar
Kai Chen committed
9

Kai Chen's avatar
Kai Chen committed
10
11
from ..registry import BACKBONES

Kai Chen's avatar
Kai Chen committed
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    Padding is set equal to ``dilation``, so with ``stride=1`` the output
    keeps the input's spatial size (effective kernel 2*dilation+1, padded by
    dilation on each side).
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convs (the block type for ResNet-18/34).

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels of both convs. ``expansion``
            is 1, so this is also the block's output width.
        stride (int): Stride of the first 3x3 conv.
        dilation (int): Dilation (and padding) of the first 3x3 conv; the
            second conv always uses dilation 1.
        downsample (nn.Module, optional): If given, applied to the block
            input to produce the identity branch.
        style (str): Accepted for signature parity with ``Bottleneck``;
            unused here because the only strided layer is the first 3x3 conv.
        with_cp (bool): Run the residual body under
            ``torch.utils.checkpoint`` when the input requires grad, saving
            activation memory at the cost of recomputation.
        normalize (dict): Config forwarded to ``build_norm_layer``.
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        super(BasicBlock, self).__init__()

        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)

        # Registration order (conv1, norm1, conv2, norm2) is kept as before so
        # parameter initialization order is unchanged.
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = conv3x3(planes, planes)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        # Previously ``assert not with_cp``: ResNet forwards ``with_cp`` to
        # every block, so ResNet-18/34 crashed when checkpointing was enabled.
        self.with_cp = with_cp

    @property
    def norm1(self):
        """The norm layer following ``conv1`` (registered under a
        type-dependent name)."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """The norm layer following ``conv2``."""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """conv-norm-relu-conv-norm, add the identity branch, then ReLU."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        # Checkpointing only makes sense when gradients flow through x.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual block of 1x1 -> 3x3 -> 1x1 convs (ResNet-50/101/152).

    The final 1x1 conv widens the output to ``planes * expansion`` channels.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        """Bottleneck block for ResNet.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.

        Args:
            inplanes (int): Number of input channels.
            planes (int): Width of the first 1x1 and the 3x3 conv; the block
                outputs ``planes * self.expansion`` channels.
            stride (int): Block stride, applied to conv1 or conv2 depending
                on ``style``.
            dilation (int): Dilation (and padding) of the 3x3 conv.
            downsample (nn.Module, optional): If given, applied to the block
                input to produce the identity branch.
            style (str): 'pytorch' or 'caffe' (see above).
            with_cp (bool): Run the residual body under
                ``torch.utils.checkpoint`` when the input requires grad.
            normalize (dict): Config forwarded to ``build_norm_layer``.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        self.inplanes = inplanes
        self.planes = planes
        self.normalize = normalize
        if style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(normalize,
                                                  planes * self.expansion,
                                                  postfix=3)

        # Registration order matches the original so parameter init order and
        # state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        # Same layer as the previous inline nn.Conv2d(3x3, padding=dilation).
        self.conv2 = conv3x3(planes, planes, self.conv2_stride, dilation)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp
        # (The duplicate ``self.normalize = normalize`` that used to sit here
        # was redundant with the assignment above and has been removed.)

    @property
    def norm1(self):
        """The norm layer following ``conv1``."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """The norm layer following ``conv2``."""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """The norm layer following ``conv3``."""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Three conv-norm stages, add the identity branch, then ReLU."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        # Checkpointing only makes sense when gradients flow through x.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   normalize=dict(type='BN')):
    """Stack ``blocks`` residual blocks of type ``block`` into one stage.

    Only the first block receives ``stride`` and, when the input shape does
    not match the block output (stride != 1 or channel mismatch), a
    1x1-conv + norm projection for its identity branch. The remaining blocks
    run at stride 1 on ``planes * block.expansion`` input channels.

    Returns:
        nn.Sequential: The stacked blocks, in order.
    """
    out_planes = planes * block.expansion

    downsample = None
    if stride != 1 or inplanes != out_planes:
        # Project the identity branch to the first block's output shape.
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                out_planes,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(normalize, out_planes)[1],
        )

    head = block(
        inplanes,
        planes,
        stride,
        dilation,
        downsample,
        style=style,
        with_cp=with_cp,
        normalize=normalize)
    tail = [
        block(out_planes, planes, 1, dilation, style=style,
              with_cp=with_cp, normalize=normalize)
        for _ in range(1, blocks)
    ]
    return nn.Sequential(head, *tail)


Kai Chen's avatar
Kai Chen committed
227
@BACKBONES.register_module
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all params fixed, and the
            frozen modules kept in eval mode so BatchNorm running stats do
            not update). 0 freezes only the stem; -1 means not freezing
            anything.
        normalize (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
    """

    # depth -> (block class, number of blocks per stage)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 normalize=dict(
                     type='BN',
                     frozen=False),
                 norm_eval=True,
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.normalize = normalize
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        # Stages are registered as submodules 'layer1', 'layer2', ...; their
        # names are kept in order for forward().
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                normalize=normalize)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count of the last built stage's output.
        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        """The stem's norm layer (registered under a type-dependent name)."""
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self):
        """Build the stem: 7x7/2 conv, norm, ReLU, 3x3/2 max-pool."""
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.normalize,
                                                  64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Freeze the stem (if frozen_stages >= 0) and stages 1..frozen_stages.

        Frozen modules get ``requires_grad = False`` on every parameter AND
        are put in eval mode; without the latter their BatchNorm running
        stats would keep updating whenever the model is in train mode.
        """
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load weights from a checkpoint path, or initialize from scratch.

        Args:
            pretrained (str, optional): Checkpoint path, loaded non-strictly.
                If None, convs get Kaiming init, norm layers weight 1, and
                (with zero_init_residual) each block's last norm weight 0.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(
                        m, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
                    # _BatchNorm also covers SyncBatchNorm and other BN
                    # variants, which are not subclasses of BatchNorm2d.
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stem and all stages.

        Returns the single selected stage output if ``out_indices`` has one
        entry, otherwise a tuple of the selected stage outputs in order.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        """Set train/eval mode while keeping frozen parts and, if
        ``norm_eval``, all BatchNorm layers in eval mode.

        Returns:
            ResNet: ``self``, matching the ``nn.Module.train`` contract.
        """
        super(ResNet, self).train(mode)
        # super().train(mode) flips every submodule, including frozen ones,
        # so freezing must be re-applied here.
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # Only BatchNorm-style layers (incl. SyncBN) are switched to
                # eval so their running mean/var stay fixed.
                if isinstance(m, nn.modules.batchnorm._BatchNorm):
                    m.eval()
        return self