resnet.py 12 KB
Newer Older
1
2
import logging

Kai Chen's avatar
Kai Chen committed
3
4
import torch.nn as nn
import torch.utils.checkpoint as cp
Kai Chen's avatar
Kai Chen committed
5
6

from mmcv.cnn import constant_init, kaiming_init
Kai Chen's avatar
Kai Chen committed
7
from mmcv.runner import load_checkpoint
8
from ..utils import build_norm_layer
Kai Chen's avatar
Kai Chen committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Stride of the convolution.
        dilation (int): Dilation of the convolution. The same value is used
            as the padding.

    Returns:
        nn.Conv2d: The constructed convolution layer.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels produced by both convolutions.
        stride (int): Stride of the first convolution.
        dilation (int): Dilation of the first convolution. The second
            convolution is always built with the default stride and
            dilation of ``conv3x3``.
        downsample (nn.Module | None): Optional module applied to the input
            to produce the shortcut branch.
        style (str): Accepted for signature parity with ``Bottleneck``; it
            is not read by this block.
        with_cp (bool): Must be False; checkpointing is not supported here.
        normalize (dict): Config handed to ``build_norm_layer``. The default
            dict is shared between calls and is not modified by this class.
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        super(BasicBlock, self).__init__()
        # Reject the unsupported option before building any layer.
        assert not with_cp

        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)

        # Registration order (conv1, norm1, conv2, norm2) fixes the order of
        # the entries in ``modules()`` / ``state_dict()``.
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = conv3x3(planes, planes)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    @property
    def norm1(self):
        """The first norm layer, looked up by its registered name."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """The second norm layer, looked up by its registered name."""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Run conv-norm-relu-conv-norm, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.norm2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck block for ResNet (1x1 -> 3x3 -> 1x1 convolutions).

    The last 1x1 convolution outputs ``planes * expansion`` channels.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels of the two inner convolutions.
        stride (int): Stride of the block; which convolution carries it is
            decided by ``style``.
        dilation (int): Dilation (and padding) of the 3x3 convolution.
        downsample (nn.Module | None): Optional module applied to the input
            to produce the shortcut branch.
        style (str): ``'pytorch'`` or ``'caffe'``. If style is "pytorch",
            the stride-two layer is the 3x3 conv layer, if it is "caffe",
            the stride-two layer is the first 1x1 conv layer.
        with_cp (bool): If True, the residual branch is run through
            ``torch.utils.checkpoint`` whenever the input requires grad.
        normalize (dict): Config handed to ``build_norm_layer``. The default
            dict is shared between calls and is not modified by this class.
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        """Build the three conv/norm pairs; see the class docstring."""
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']

        self.inplanes = inplanes
        self.planes = planes
        self.normalize = normalize
        # Exactly one of the first two convolutions carries the stride.
        if style == 'pytorch':
            self.conv1_stride, self.conv2_stride = 1, stride
        else:
            self.conv1_stride, self.conv2_stride = stride, 1

        out_planes = planes * self.expansion
        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(normalize,
                                                  out_planes,
                                                  postfix=3)

        # Registration order (conv1, norm1, conv2, norm2, conv3, norm3)
        # fixes the order of the entries in ``modules()``/``state_dict()``.
        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = nn.Conv2d(
            planes, out_planes, kernel_size=1, bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    @property
    def norm1(self):
        """The first norm layer, looked up by its registered name."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """The second norm layer, looked up by its registered name."""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """The third norm layer, looked up by its registered name."""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Apply the residual branch, add the shortcut, then ReLU."""

        def _inner_forward(x):
            # Everything up to (and including) the residual addition; kept
            # as a closure so it can be wrapped by ``cp.checkpoint``.
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   normalize=dict(type='BN')):
    """Stack ``blocks`` residual blocks into one ``nn.Sequential`` stage.

    Args:
        block (type): Block class to instantiate; it must expose an
            ``expansion`` attribute.
        inplanes (int): Number of channels entering the stage.
        planes (int): Base channel count passed to every block; the stage
            outputs ``planes * block.expansion`` channels.
        blocks (int): Number of blocks in the stage. The first block is
            always created; the remaining ``blocks - 1`` follow it.
        stride (int): Stride of the first block; later blocks use stride 1.
        dilation (int): Dilation passed to every block.
        style (str): Forwarded to every block.
        with_cp (bool): Forwarded to every block.
        normalize (dict): Norm layer config, forwarded to every block and
            to ``build_norm_layer`` for the shortcut projection.

    Returns:
        nn.Sequential: The assembled stage.
    """
    out_planes = planes * block.expansion

    # A projection shortcut is needed when the first block changes either
    # the spatial resolution or the channel count.
    downsample = None
    if stride != 1 or inplanes != out_planes:
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                out_planes,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(normalize, out_planes)[1],
        )

    layers = [
        block(
            inplanes,
            planes,
            stride,
            dilation,
            downsample,
            style=style,
            with_cp=with_cp,
            normalize=normalize)
    ]
    # Remaining blocks keep the resolution and take the expanded channels.
    layers.extend(
        block(out_planes, planes, 1, dilation, style=style,
              with_cp=with_cp, normalize=normalize)
        for _ in range(1, blocks))

    return nn.Sequential(*layers)


Kai Chen's avatar
Kai Chen committed
225
226
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        normalize (dict): dictionary to construct norm layer. It must also
            contain the keys ``eval_mode`` (bool) and ``frozen`` (bool).
            ``eval_mode`` is removed from a private copy and decides whether
            BatchNorm layers are kept in eval mode during training;
            ``frozen`` stays in the config passed to ``build_norm_layer``.
            The dict passed in is never modified.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): Whether ``init_weights`` zero-initializes
            the last norm layer of every residual block when no pretrained
            checkpoint is given.
    """

    # depth -> (block class, number of blocks in each of the four stages)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 normalize=dict(
                     type='BN',
                     eval_mode=True,
                     frozen=False),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        assert (isinstance(normalize, dict) and 'eval_mode' in normalize
                and 'frozen' in normalize)
        # Work on a shallow copy: popping from the argument itself would
        # strip 'eval_mode' from the shared default dict (breaking every
        # later default construction) and from the caller's config.
        normalize = dict(normalize)
        self.norm_eval = normalize.pop('eval_mode')
        self.normalize = normalize
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        # Stages are registered as ``layer1`` .. ``layerN``; only their
        # names are kept here and the modules are fetched with getattr.
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                normalize=normalize)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count produced by the last constructed stage.
        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    @property
    def norm1(self):
        """The stem norm layer, looked up by its registered name."""
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self):
        """Create the stem: 7x7 stride-2 conv, norm, ReLU and 3x3 max-pool."""
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1_name, norm1 = build_norm_layer(self.normalize,
                                                  64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Disable gradients for the stem and the first frozen stages.

        With ``frozen_stages >= 0`` the stem conv and norm are frozen; in
        addition ``layer1`` .. ``layer{frozen_stages}`` are frozen.
        """
        if self.frozen_stages >= 0:
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load a checkpoint or initialize the weights from scratch.

        Args:
            pretrained (str | None): Checkpoint passed to
                ``load_checkpoint`` (non-strict). If None, conv layers get
                Kaiming init and BatchNorm/GroupNorm layers are set to 1.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    constant_init(m, 1)

            # zero init for last norm layer https://arxiv.org/abs/1706.02677
            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stem and all stages, collecting the requested outputs.

        Returns:
            The output of the single selected stage, or a tuple with the
            outputs of every stage listed in ``out_indices``.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode, keeping BatchNorm in eval if requested."""
        super(ResNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()