resnet.py 12 KB
Newer Older
1
2
import logging

Kai Chen's avatar
Kai Chen committed
3
4
import torch.nn as nn
import torch.utils.checkpoint as cp
Kai Chen's avatar
Kai Chen committed
5
6

from mmcv.cnn import constant_init, kaiming_init
Kai Chen's avatar
Kai Chen committed
7
from mmcv.runner import load_checkpoint
8
from ..utils import build_norm_layer
Kai Chen's avatar
Kai Chen committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Create a bias-free 3x3 ``nn.Conv2d``.

    Padding is set equal to ``dilation``, so with ``stride=1`` the output
    keeps the input's spatial size.
    """
    pad = dilation
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=pad, dilation=dilation, bias=False)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions (used by ResNet-18/34).

    Args:
        inplanes (int): Channels of the block input.
        planes (int): Channels produced by both 3x3 convs.
        stride (int): Stride of the first 3x3 conv.
        dilation (int): Dilation (and padding) of the first 3x3 conv; the
            second conv is built with the ``conv3x3`` defaults.
        downsample (nn.Module, optional): Applied to the block input to form
            the shortcut when its shape differs from the branch output.
        style (str): Accepted for signature parity with ``Bottleneck``; it is
            not used by this block.
        with_cp (bool): If True and the input requires grad, the residual
            branch runs under ``torch.utils.checkpoint`` to save activation
            memory at the cost of recomputation.
        normalize (dict): Config forwarded to ``build_norm_layer``.
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        super(BasicBlock, self).__init__()

        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)

        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = conv3x3(planes, planes)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        # Checkpointing is supported exactly as in Bottleneck (see forward).
        self.with_cp = with_cp

    @property
    def norm1(self):
        """Norm layer following ``conv1`` (registered under ``norm1_name``)."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Norm layer following ``conv2`` (registered under ``norm2_name``)."""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Compute ``relu(branch(x) + shortcut(x))``."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        # Same policy as Bottleneck: only checkpoint when gradients flow
        # through the input; otherwise run the branch directly.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand (x4).

    The ``style`` argument selects which conv carries the block stride:
    "pytorch" puts it on the 3x3 conv, "caffe" on the first 1x1 conv.

    Args:
        inplanes (int): Channels of the block input.
        planes (int): Width of the inner convs; the block outputs
            ``planes * expansion`` channels.
        stride (int): Block stride (placed according to ``style``).
        dilation (int): Dilation and padding of the 3x3 conv.
        downsample (nn.Module, optional): Shortcut projection applied to the
            input when its shape differs from the branch output.
        style (str): "pytorch" or "caffe".
        with_cp (bool): Run the residual branch under gradient checkpointing
            when the input requires grad.
        normalize (dict): Config forwarded to ``build_norm_layer``.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        self.inplanes = inplanes
        self.planes = planes
        self.normalize = normalize

        # "pytorch" strides the 3x3 conv; every other style strides conv1.
        stride_on_3x3 = style == 'pytorch'
        self.conv1_stride = 1 if stride_on_3x3 else stride
        self.conv2_stride = stride if stride_on_3x3 else 1

        out_planes = planes * self.expansion
        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            normalize, out_planes, postfix=3)

        # Submodules are registered conv/norm interleaved so that module
        # iteration order stays conv1, norm1, conv2, norm2, conv3, norm3.
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=1, stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=self.conv2_stride,
            padding=dilation, dilation=dilation, bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    @property
    def norm1(self):
        """Norm layer after ``conv1``."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Norm layer after ``conv2``."""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """Norm layer after ``conv3``."""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Compute ``relu(branch(x) + shortcut(x))``."""

        def _residual(inp):
            out = inp
            # The first two conv/norm pairs are each followed by ReLU.
            for conv, norm in ((self.conv1, self.norm1),
                               (self.conv2, self.norm2)):
                out = self.relu(norm(conv(out)))
            # The last pair is added to the shortcut before the final ReLU.
            out = self.norm3(self.conv3(out))
            shortcut = inp if self.downsample is None else self.downsample(inp)
            out += shortcut
            return out

        use_checkpoint = self.with_cp and x.requires_grad
        out = cp.checkpoint(_residual, x) if use_checkpoint else _residual(x)
        return self.relu(out)


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   normalize=dict(type='BN')):
    """Build one ResNet stage as an ``nn.Sequential`` of ``block`` modules.

    Args:
        block (type): Residual block class; its ``expansion`` attribute is
            the output-channel multiplier.
        inplanes (int): Channels entering the stage.
        planes (int): Base width; each block outputs
            ``planes * block.expansion`` channels.
        blocks (int): Number of blocks. The first block is always built, so
            values below 1 still yield a single-block stage.
        stride (int): Stride of the first block only.
        dilation (int): Dilation passed to every block.
        style, with_cp, normalize: Forwarded unchanged to every block.

    Returns:
        nn.Sequential: The stacked blocks.
    """
    out_channels = planes * block.expansion

    # Only the first block can change resolution or width, so only it may
    # need a 1x1 conv + norm projection on the shortcut.
    downsample = None
    if stride != 1 or inplanes != out_channels:
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(normalize, out_channels)[1],
        )

    shared = dict(style=style, with_cp=with_cp, normalize=normalize)
    head = block(inplanes, planes, stride, dilation, downsample, **shared)
    tail = [
        block(out_channels, planes, 1, dilation, **shared)
        for _ in range(blocks - 1)
    ]
    return nn.Sequential(head, *tail)


Kai Chen's avatar
Kai Chen committed
225
226
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen: all their parameters are
            fixed and the modules are kept in eval mode, so their BatchNorm
            running stats are not updated either. -1 means not freezing
            anything; 0 freezes only the stem (conv1 + norm1).
        normalize (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants (including SyncBatchNorm) only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 normalize=dict(
                     type='BN',
                     frozen=False),
                 norm_eval=True,
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.normalize = normalize
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                normalize=normalize)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        """Stem norm layer (registered under ``norm1_name``)."""
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self):
        """Build the stem: 7x7/2 conv -> norm -> ReLU -> 3x3/2 max-pool."""
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.normalize,
                                                  64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Freeze the stem and the first ``frozen_stages`` residual stages.

        Frozen modules get ``requires_grad=False`` on every parameter and are
        put in eval mode, so BatchNorm-style layers inside them stop updating
        running statistics. Called from ``train()`` as well, because
        ``nn.Module.train`` would otherwise switch them back to training mode.
        """
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load a checkpoint from ``pretrained`` or initialize from scratch.

        Args:
            pretrained (str, optional): Checkpoint path/URL passed to
                ``load_checkpoint`` (non-strict). If None, convs get kaiming
                init, BN-family/GN layers get weight 1, and optionally the
                last norm of each residual block is zeroed.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                # _BatchNorm also covers SyncBatchNorm, which is not a
                # subclass of BatchNorm2d.
                elif isinstance(m, (nn.modules.batchnorm._BatchNorm,
                                    nn.GroupNorm)):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Return the feature map(s) of the stages listed in ``out_indices``.

        A single tensor is returned when exactly one stage is selected,
        otherwise a tuple in stage order.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        """Set train/eval mode, keeping frozen stages (and, if ``norm_eval``,
        all BatchNorm-family layers) in eval mode.

        Returns:
            ResNet: ``self``, matching ``nn.Module.train``.
        """
        super(ResNet, self).train(mode)
        # super().train(mode) just flipped frozen stages back to training.
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # eval() only matters for layers that keep running stats;
                # _BatchNorm covers BatchNorm*d and SyncBatchNorm.
                if isinstance(m, nn.modules.batchnorm._BatchNorm):
                    m.eval()
        return self