resnet.py 11.6 KB
Newer Older
1
2
import logging

Kai Chen's avatar
Kai Chen committed
3
4
import torch.nn as nn
import torch.utils.checkpoint as cp
Kai Chen's avatar
Kai Chen committed
5
6

from mmcv.cnn import constant_init, kaiming_init
Kai Chen's avatar
Kai Chen committed
7
from mmcv.runner import load_checkpoint
8
from ..utils import build_norm_layer
Kai Chen's avatar
Kai Chen committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation`` so that, with stride 1, the
    output has the same spatial size as the input.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    The output channel count equals ``planes`` (``expansion`` is 1).

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels of both convolutions.
        stride (int): Stride of the first 3x3 convolution.
        dilation (int): Dilation (and padding) of the first 3x3 convolution.
            The second convolution always uses stride 1 and dilation 1.
        downsample (nn.Module | None): Optional module applied to the block
            input to produce the shortcut branch.
        style (str): Accepted for signature parity with ``Bottleneck``; it is
            not used by this block.
        with_cp (bool): Must be False; checkpointing is not supported here.
        normalize (dict): Config passed to ``build_norm_layer``. Its ``type``
            entry also selects the attribute names of the norm layers
            (``gn1``/``gn2`` for ``'GN'``, ``bn1``/``bn2`` otherwise).
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)

        # Register the two norm layers under names that depend on the norm
        # type; forward() looks them up again through ``self.norm_names``.
        norm_layers = [
            build_norm_layer(normalize, planes),
            build_norm_layer(normalize, planes),
        ]
        self.norm_names = (['gn1', 'gn2'] if normalize['type'] == 'GN'
                           else ['bn1', 'bn2'])
        for name, layer in zip(self.norm_names, norm_layers):
            self.add_module(name, layer)

        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        assert not with_cp, 'BasicBlock does not support with_cp'

    def forward(self, x):
        """Apply conv-norm-relu-conv-norm, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = getattr(self, self.norm_names[0])(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = getattr(self, self.norm_names[1])(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck block for ResNet (1x1 -> 3x3 -> 1x1 convolutions).

    The output channel count is ``planes * expansion`` (``expansion`` is 4).

    If style is "pytorch", the stride-two layer is the 3x3 conv layer,
    if it is "caffe", the stride-two layer is the first 1x1 conv layer.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels of the two inner convolutions.
        stride (int): Stride of the block, applied to the conv selected by
            ``style``.
        dilation (int): Dilation (and padding) of the 3x3 convolution.
        downsample (nn.Module | None): Optional module applied to the block
            input to produce the shortcut branch.
        style (str): Either ``'pytorch'`` or ``'caffe'``.
        with_cp (bool): If True and the input requires grad, the block body
            is run through ``torch.utils.checkpoint.checkpoint``.
        normalize (dict): Config passed to ``build_norm_layer``. Its ``type``
            entry also selects the attribute names of the norm layers
            (``gn1..gn3`` for ``'GN'``, ``bn1..bn3`` otherwise).
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN')):
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        self.inplanes = inplanes
        self.planes = planes

        # Decide which convolution carries the block stride.
        if style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=dilation,
            dilation=dilation,
            bias=False)

        # Register the three norm layers under names that depend on the norm
        # type; forward() looks them up again through ``self.norm_names``.
        norm_layers = [
            build_norm_layer(normalize, planes),
            build_norm_layer(normalize, planes),
            build_norm_layer(normalize, planes * self.expansion),
        ]
        self.norm_names = (['gn1', 'gn2', 'gn3'] if normalize['type'] == 'GN'
                           else ['bn1', 'bn2', 'bn3'])
        for name, layer in zip(self.norm_names, norm_layers):
            self.add_module(name, layer)

        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp
        self.normalize = normalize

    def forward(self, x):
        """Run the bottleneck body, add the shortcut, then apply ReLU."""

        def _inner_forward(x):
            # Everything up to (and including) the residual addition; kept as
            # a closure so it can be handed to the checkpoint utility.
            identity = x

            out = self.conv1(x)
            out = getattr(self, self.norm_names[0])(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = getattr(self, self.norm_names[1])(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = getattr(self, self.norm_names[2])(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   normalize=dict(type='BN')):
    """Stack ``blocks`` residual blocks into one ``nn.Sequential`` stage.

    Only the first block receives ``stride`` and (when needed) a shortcut
    projection; the remaining blocks use stride 1 and take
    ``planes * block.expansion`` input channels.

    Args:
        block (type): Block class exposing an ``expansion`` attribute and the
            constructor signature used below.
        inplanes (int): Number of input channels of the stage.
        planes (int): Base channel count passed to every block.
        blocks (int): Number of blocks in the stage.
        stride (int): Stride of the first block.
        dilation (int): Dilation passed to every block.
        style (str): Forwarded to every block.
        with_cp (bool): Forwarded to every block.
        normalize (dict): Norm config forwarded to every block and to
            ``build_norm_layer`` for the shortcut projection.

    Returns:
        nn.Sequential: The assembled stage.
    """
    out_planes = planes * block.expansion

    # A 1x1 conv + norm projection is needed on the shortcut whenever the
    # first block changes the spatial size or the channel count.
    downsample = None
    if stride != 1 or inplanes != out_planes:
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                out_planes,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(normalize, out_planes),
        )

    layers = [
        block(
            inplanes,
            planes,
            stride,
            dilation,
            downsample,
            style=style,
            with_cp=with_cp,
            normalize=normalize)
    ]
    for _ in range(1, blocks):
        layers.append(
            block(out_planes, planes, 1, dilation, style=style,
                  with_cp=with_cp, normalize=normalize))

    return nn.Sequential(*layers)


Kai Chen's avatar
Kai Chen committed
205
206
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters; 0 freezes only the stem (conv1 and
            its norm layer); k >= 1 additionally freezes layer1..layerk.
        normalize (dict): Dictionary to construct norm layers. It must
            contain the keys ``eval_mode`` (bool) and ``frozen`` (bool).
            ``eval_mode`` is removed from a copy of the dict and controls
            whether ``nn.BatchNorm2d`` layers are kept in eval mode while
            training; the remaining entries (including ``frozen``) are
            forwarded to ``build_norm_layer``. The dict passed in is not
            modified.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): When initializing from scratch, set the
            weight of the last norm layer of every residual block to 0.
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 normalize=dict(
                     type='BN',
                     eval_mode=True,
                     frozen=False),
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert 1 <= num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        assert (isinstance(normalize, dict) and 'eval_mode' in normalize
                and 'frozen' in normalize)
        # Work on a copy: popping from the argument itself would corrupt the
        # shared default dict (and any caller-owned config), so a second
        # construction with the same dict would fail the assertion above.
        normalize = normalize.copy()
        self.norm_eval = normalize.pop('eval_mode')
        self.normalize = normalize
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        # Names of the registered stage modules, in forward order.
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                normalize=normalize)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Channel count of the last stage's output.
        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    def _make_stem_layer(self):
        """Create the 7x7 conv / norm / ReLU / max-pool stem.

        The stem parameters are frozen when ``frozen_stages`` is >= 0.
        """
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        stem_norm = build_norm_layer(self.normalize, 64)
        self.norm_name = 'gn1' if self.normalize['type'] == 'GN' else 'bn1'
        self.add_module(self.norm_name, stem_norm)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        if self.frozen_stages >= 0:
            for m in [self.conv1, stem_norm]:
                for param in m.parameters():
                    param.requires_grad = False

    def _freeze_stages(self):
        """Disable gradients for layer1..layer{frozen_stages}."""
        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load a checkpoint or initialize weights from scratch.

        Args:
            pretrained (str | None): Checkpoint path/URL handed to
                ``load_checkpoint``; if None, conv layers are initialized
                with ``kaiming_init`` and norm layers with constant 1.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                # ``torch.nn`` has no ``BatchNorm`` attribute; the 2d variant
                # is the batch-norm class this file checks for elsewhere.
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    constant_init(m, 1)

            # zero init for last norm layer https://arxiv.org/abs/1706.02677
            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, (Bottleneck, BasicBlock)):
                        last_norm = getattr(m, m.norm_names[-1])
                        constant_init(last_norm, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stem and all stages, collecting the requested outputs.

        Returns:
            The single stage output when exactly one index is selected by
            ``out_indices``, otherwise a tuple of the selected outputs.
        """
        x = self.conv1(x)
        x = getattr(self, self.norm_name)(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode, optionally keeping BatchNorm2d in eval."""
        super(ResNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()