resnet.py 11.9 KB
Newer Older
1
2
import logging

Kai Chen's avatar
Kai Chen committed
3
4
import torch.nn as nn
import torch.utils.checkpoint as cp
Kai Chen's avatar
Kai Chen committed
5
6

from mmcv.cnn import constant_init, kaiming_init
Kai Chen's avatar
Kai Chen committed
7
from mmcv.runner import load_checkpoint
8
from ..utils import build_norm_layer
Kai Chen's avatar
Kai Chen committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to the dilation, so with ``stride=1`` the
    spatial size of the input is preserved.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Stride of the convolution.
        dilation (int): Dilation of the kernel; also used as the padding.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions.

    Each convolution is followed by a norm layer created with
    ``build_norm_layer``; the shortcut branch is added before the final ReLU.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels produced by both convolutions.
        stride (int): Stride of the first convolution.
        dilation (int): Dilation (and padding) of the first convolution.
        downsample (nn.Module | None): Module applied to the input to build
            the shortcut branch. If None, the input itself is the shortcut.
        style (str): Accepted so that the signature matches ``Bottleneck``;
            it is not used by this block.
        with_cp (bool): Must be False; this block does not implement
            checkpointing.
        normalize (dict): Config passed to ``build_norm_layer``.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)

        # Norm layers are registered under the names returned by
        # build_norm_layer and are reached through the norm1/norm2 properties.
        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
        self.add_module(self.norm1_name, norm1)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        # The second convolution always uses stride 1 and dilation 1.
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        assert not with_cp

    @property
    def norm1(self):
        """Norm layer applied after ``conv1``."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Norm layer applied after ``conv2``."""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Run conv-norm-relu-conv-norm, add the shortcut, apply ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.norm2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions).

    Every convolution is followed by a norm layer created with
    ``build_norm_layer``; the shortcut branch is added before the final ReLU.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        """Bottleneck block for ResNet.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.

        Args:
            inplanes (int): Number of input channels.
            planes (int): Number of channels of the two inner convolutions;
                the block outputs ``planes * expansion`` channels.
            stride (int): Stride of the block, placed according to ``style``.
            dilation (int): Dilation (and padding) of the 3x3 convolution.
            downsample (nn.Module | None): Module applied to the input to
                build the shortcut branch. If None, the input itself is used.
            style (str): ``'pytorch'`` or ``'caffe'``.
            with_cp (bool): Use checkpoint or not. Checkpointing is only
                applied when the input requires grad.
            normalize (dict): Config passed to ``build_norm_layer``.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        self.inplanes = inplanes
        self.planes = planes
        if style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1
        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=dilation,
            dilation=dilation,
            bias=False)

        # Norm layers are registered under the names returned by
        # build_norm_layer and are reached through the norm1..norm3 properties.
        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(normalize,
                                                  planes * self.expansion,
                                                  postfix=3)
        self.add_module(self.norm1_name, norm1)
        self.add_module(self.norm2_name, norm2)
        self.add_module(self.norm3_name, norm3)

        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp
        self.normalize = normalize

    @property
    def norm1(self):
        """Norm layer applied after ``conv1``."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Norm layer applied after ``conv2``."""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """Norm layer applied after ``conv3``."""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Run the residual branch, add the shortcut, apply the final ReLU."""

        def _inner_forward(x):
            # Everything up to (and including) the residual addition lives in
            # this closure so it can be handed to cp.checkpoint as one unit.
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   normalize=dict(type='BN')):
    """Stack residual blocks into one ``nn.Sequential`` stage.

    Only the first block receives ``stride`` and the ``downsample`` module;
    the remaining blocks use stride 1 and take the stage's output channel
    count as their input.

    Args:
        block (type): Block class to instantiate. It must expose an
            ``expansion`` attribute and accept the arguments passed below.
        inplanes (int): Number of input channels of the stage.
        planes (int): Base channel count; the stage outputs
            ``planes * block.expansion`` channels.
        blocks (int): Number of blocks in the stage. The first block is
            always created, even if this is less than 1.
        stride (int): Stride of the first block.
        dilation (int): Dilation passed to every block.
        style (str): Forwarded to every block.
        with_cp (bool): Forwarded to every block.
        normalize (dict): Norm config, forwarded to every block and to
            ``build_norm_layer`` for the downsample branch.

    Returns:
        nn.Sequential: The assembled stage.
    """
    out_channels = planes * block.expansion

    # A projection shortcut is needed whenever the first block changes the
    # spatial resolution or the channel count.
    downsample = None
    if stride != 1 or inplanes != out_channels:
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(normalize, out_channels)[1],
        )

    layers = [
        block(
            inplanes,
            planes,
            stride,
            dilation,
            downsample,
            style=style,
            with_cp=with_cp,
            normalize=normalize)
    ]
    layers.extend(
        block(out_channels, planes, 1, dilation, style=style,
              with_cp=with_cp, normalize=normalize)
        for _ in range(1, blocks))

    return nn.Sequential(*layers)


Kai Chen's avatar
Kai Chen committed
222
223
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters. 0 freezes only the stem.
        normalize (dict): dictionary to construct norm layer. It must also
            contain the keys ``eval_mode`` (bool) and ``frozen`` (bool).
            ``eval_mode`` is consumed here and keeps BatchNorm layers in eval
            mode while training; the remaining keys (including ``frozen``)
            are passed on to ``build_norm_layer``. The dict passed in is not
            modified.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): When initializing without a pretrained
            checkpoint, set the last norm layer of every residual block to 0.
    """

    # depth -> (block class, number of blocks in each of the four stages)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 normalize=dict(
                     type='BN',
                     eval_mode=True,
                     frozen=False),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        assert (isinstance(normalize, dict) and 'eval_mode' in normalize
                and 'frozen' in normalize)
        # Pop from a copy: the default dict is shared by every call, so
        # popping from it directly would make the next default-constructed
        # ResNet fail the assertion above (and would silently edit a
        # caller-supplied config).
        normalize = normalize.copy()
        self.norm_eval = normalize.pop('eval_mode')
        self.normalize = normalize
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        # Stages are registered as layer1..layerN; res_layers keeps the names
        # in order so forward() can look them up.
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                normalize=normalize)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count of the last stage's output.
        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        """Norm layer of the stem, applied after ``conv1``."""
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self):
        """Create the stem: 7x7 stride-2 conv, norm, ReLU and 3x3 max pool."""
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.normalize,
                                                  64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Disable gradients for the stem and the first ``frozen_stages`` stages."""
        if self.frozen_stages >= 0:
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load a checkpoint or initialize the weights from scratch.

        Args:
            pretrained (str | None): Checkpoint path passed to
                ``load_checkpoint`` (non-strict). If None, conv layers get
                ``kaiming_init`` and BatchNorm/GroupNorm layers get
                ``constant_init(m, 1)``.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    constant_init(m, 1)

            # zero init for last norm layer https://arxiv.org/abs/1706.02677
            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stem and all stages, collecting the selected outputs.

        Returns:
            The single selected stage output if exactly one was collected,
            otherwise a tuple of the collected outputs in stage order.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode, keeping BatchNorm in eval if ``norm_eval``."""
        super(ResNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()