import logging

import torch.nn as nn
import torch.utils.checkpoint as cp

from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from ..utils import build_norm_layer


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Build a bias-free 3x3 convolution padded by its dilation.

    With a 3x3 kernel, a padding equal to the dilation keeps the spatial
    size unchanged when ``stride`` is 1.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Stride of the convolution.
        dilation (int): Dilation of the kernel; also used as the padding.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    Args:
        inplanes (int): Channels of the input feature map.
        planes (int): Channels produced by both convolutions.
        stride (int): Stride of the first convolution.
        dilation (int): Dilation of the first convolution.
        downsample (nn.Module | None): Optional module applied to the
            input before it is added to the residual branch.
        style (str): Accepted so the signature matches ``Bottleneck``;
            this block does not read it.
        with_cp (bool): Must be False; checkpointing is rejected by an
            assertion in this block.
        normalize (dict): Config handed to ``build_norm_layer``.
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)

        # Norm layers are registered under the names returned by
        # build_norm_layer; the properties below fetch them by name.
        self.norm1_name, n1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, n2 = build_norm_layer(normalize, planes, postfix=2)
        for name, layer in ((self.norm1_name, n1), (self.norm2_name, n2)):
            self.add_module(name, layer)

        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride, self.dilation = stride, dilation
        assert not with_cp

    @property
    def norm1(self):
        """Norm layer that follows ``conv1``."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Norm layer that follows ``conv2``."""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Run conv-norm-relu-conv-norm, add the shortcut, apply ReLU."""
        feat = self.relu(self.norm1(self.conv1(x)))
        feat = self.norm2(self.conv2(feat))

        # Use the raw input as shortcut unless a downsample module is set.
        shortcut = x if self.downsample is None else self.downsample(x)
        feat += shortcut

        return self.relu(feat)


class Bottleneck(nn.Module):
    """Bottleneck block for ResNet (1x1 -> 3x3 -> 1x1 convolutions).

    If style is "pytorch", the stride-two layer is the 3x3 conv layer,
    if it is "caffe", the stride-two layer is the first 1x1 conv layer.

    Args:
        inplanes (int): Channels of the input feature map.
        planes (int): Channels of the two inner convolutions; the block
            outputs ``planes * expansion`` channels.
        stride (int): Stride applied by the conv selected via ``style``.
        dilation (int): Dilation (and padding) of the 3x3 convolution.
        downsample (nn.Module | None): Optional module applied to the
            input before it is added to the residual branch.
        style (str): Either 'pytorch' or 'caffe'.
        with_cp (bool): Run the residual branch through
            ``torch.utils.checkpoint`` when the input requires grad.
        normalize (dict): Config handed to ``build_norm_layer``.
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        self.inplanes = inplanes
        self.planes = planes
        self.normalize = normalize

        # Decide which convolution carries the stride.
        stride_on_3x3 = style == 'pytorch'
        self.conv1_stride = 1 if stride_on_3x3 else stride
        self.conv2_stride = stride if stride_on_3x3 else 1

        expanded = planes * self.expansion

        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=dilation,
            dilation=dilation,
            bias=False)

        # Norm layers are registered under the names returned by
        # build_norm_layer; the properties below fetch them by name.
        self.norm1_name, n1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, n2 = build_norm_layer(normalize, planes, postfix=2)
        self.norm3_name, n3 = build_norm_layer(normalize, expanded, postfix=3)
        named_norms = (
            (self.norm1_name, n1),
            (self.norm2_name, n2),
            (self.norm3_name, n3),
        )
        for name, layer in named_norms:
            self.add_module(name, layer)

        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride, self.dilation = stride, dilation
        self.with_cp = with_cp

    @property
    def norm1(self):
        """Norm layer that follows ``conv1``."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Norm layer that follows ``conv2``."""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """Norm layer that follows ``conv3``."""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Apply the residual branch plus shortcut, then the final ReLU."""

        def _residual_sum(inp):
            feat = self.relu(self.norm1(self.conv1(inp)))
            feat = self.relu(self.norm2(self.conv2(feat)))
            feat = self.norm3(self.conv3(feat))

            if self.downsample is None:
                shortcut = inp
            else:
                shortcut = self.downsample(inp)
            feat += shortcut

            return feat

        # Checkpointing is only used when the input takes part in autograd.
        if self.with_cp and x.requires_grad:
            summed = cp.checkpoint(_residual_sum, x)
        else:
            summed = _residual_sum(x)

        return self.relu(summed)


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   normalize=dict(type='BN')):
    """Stack ``block`` instances into one residual stage.

    The first block receives ``stride`` and, when the stride is not 1 or
    the channel count changes, a 1x1-conv + norm ``downsample`` module.
    The remaining ``blocks - 1`` blocks use stride 1 and no downsample.

    Args:
        block (type): Block class exposing an ``expansion`` attribute.
        inplanes (int): Channels entering the stage.
        planes (int): Base channels of each block.
        blocks (int): Number of blocks in the stage.
        stride (int): Stride of the first block.
        dilation (int): Dilation passed to every block.
        style (str): Forwarded to every block.
        with_cp (bool): Forwarded to every block.
        normalize (dict): Norm config, forwarded to every block and used
            for the downsample norm layer.

    Returns:
        nn.Sequential: The blocks of the stage, in order.
    """
    out_planes = planes * block.expansion

    downsample = None
    if not (stride == 1 and inplanes == out_planes):
        shortcut_conv = nn.Conv2d(
            inplanes, out_planes, kernel_size=1, stride=stride, bias=False)
        shortcut_norm = build_norm_layer(normalize, out_planes)[1]
        downsample = nn.Sequential(shortcut_conv, shortcut_norm)

    # Options shared by every block of the stage.
    shared = dict(style=style, with_cp=with_cp, normalize=normalize)

    stage = [block(inplanes, planes, stride, dilation, downsample, **shared)]
    for _ in range(blocks - 1):
        stage.append(block(out_planes, planes, 1, dilation, **shared))

    return nn.Sequential(*stage)


Kai Chen's avatar
Kai Chen committed
223
224
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        normalize (dict): dictionary to construct norm layer. Additionally,
            eval mode and gradient freezing are controlled by
            eval_mode (bool) and frozen (bool) respectively. Both keys are
            required. The dict is copied before ``eval_mode`` is removed,
            so neither the caller's dict nor the shared default is
            modified.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): If True, ``init_weights`` sets the last
            norm layer of every residual block to zero when no pretrained
            checkpoint is given.
    """

    # depth -> (block class, number of blocks per stage)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 normalize=dict(
                     type='BN',
                     eval_mode=True,
                     frozen=False),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert 1 <= num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        assert (isinstance(normalize, dict) and 'eval_mode' in normalize
                and 'frozen' in normalize)
        # Work on a copy: popping from the argument itself would strip
        # 'eval_mode' from the shared default dict (and from the caller's
        # config), making the next construction fail the assertion above.
        normalize = dict(normalize)
        self.norm_eval = normalize.pop('eval_mode')
        self.normalize = normalize
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        # Names of the registered stages, in forward order.
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                normalize=self.normalize)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count of the last stage's output.
        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        """Norm layer of the stem, looked up by its registered name."""
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self):
        """Create the stem: 7x7 stride-2 conv, norm, ReLU and max-pool."""
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.normalize,
                                                  64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Disable gradients for the stem and the first frozen stages.

        With ``frozen_stages >= 0`` the stem conv and norm are frozen;
        stages ``layer1`` .. ``layer{frozen_stages}`` are frozen as well.
        """
        if self.frozen_stages >= 0:
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load a checkpoint or initialise the weights from scratch.

        Args:
            pretrained (str | None): Checkpoint passed to
                ``load_checkpoint``. When None, convolutions are
                initialised with ``kaiming_init`` and BatchNorm/GroupNorm
                layers with constant 1.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    constant_init(m, 1)

            # zero init for last norm layer https://arxiv.org/abs/1706.02677
            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stem and every stage, collecting selected outputs.

        Returns:
            The single selected stage output when exactly one output was
            collected, otherwise a tuple of the collected outputs.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        return tuple(outs)

    def train(self, mode=True):
        """Set train/eval mode, keeping BatchNorm in eval if requested."""
        super(ResNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()