import logging

import torch.nn as nn
import torch.utils.checkpoint as cp

from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from ..utils import build_norm_layer


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    "3x3 convolution with padding"
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN'),
                 frozen=False):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)

        norm_layers = []
        norm_layers.append(build_norm_layer(normalize, planes))
        norm_layers.append(build_norm_layer(normalize, planes))
        self.norm_names = (['gn1', 'gn2'] if normalize['type'] == 'GN'
                           else ['bn1', 'bn2'])
        for name, layer in zip(self.norm_names, norm_layers):
            self.add_module(name, layer)

        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        assert not with_cp

        if frozen:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = getattr(self, self.norm_names[0])(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = getattr(self, self.norm_names[1])(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 normalize=dict(type='BN'),
                 frozen=False):
        """Bottleneck block for ResNet.
91
92
        If style is "pytorch", the stride-two layer is the 3x3 conv layer,
        if it is "caffe", the stride-two layer is the first 1x1 conv layer.
Kai Chen's avatar
Kai Chen committed
93
94
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        self.inplanes = inplanes
        self.planes = planes
        if style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1
        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=dilation,
            dilation=dilation,
            bias=False)

        norm_layers = []
        norm_layers.append(build_norm_layer(normalize, planes))
        norm_layers.append(build_norm_layer(normalize, planes))
        norm_layers.append(build_norm_layer(normalize, planes*self.expansion))
        self.norm_names = (['gn1', 'gn2', 'gn3'] if normalize['type'] == 'GN'
                           else ['bn1', 'bn2', 'bn3'])
        for name, layer in zip(self.norm_names, norm_layers):
            self.add_module(name, layer)

        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp
        self.normalize = normalize

        if frozen:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = getattr(self, self.norm_names[0])(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = getattr(self, self.norm_names[1])(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = getattr(self, self.norm_names[2])(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False,
                   normalize=dict(type='BN'),
                   frozen=False):
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(normalize, planes * block.expansion),
        )

    layers = []
    layers.append(
        block(
            inplanes,
            planes,
            stride,
            dilation,
            downsample,
            style=style,
            with_cp=with_cp,
            normalize=normalize,
            frozen=frozen))
    inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(
            block(inplanes, planes, 1, dilation, style=style,
                  with_cp=with_cp, normalize=normalize))

    return nn.Sequential(*layers)
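
# Illustrative sketch (not part of the original module): make_res_layer builds
# one full residual stage. For example, a stage equivalent to ResNet-50 "layer1"
# could be constructed roughly as
#
#     make_res_layer(Bottleneck, inplanes=64, planes=64, blocks=3, stride=1,
#                    normalize=dict(type='BN'))
#
# which returns an nn.Sequential of 3 Bottleneck blocks; the first block gets a
# 1x1-conv downsample branch because inplanes != planes * Bottleneck.expansion.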


class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        normalize (dict): Dictionary to construct the norm layer. Additionally,
            eval mode and gradient freezing are controlled by the 'eval' (bool)
            and 'frozen' (bool) keys, respectively.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 normalize=dict(
                     type='BN',
                     eval=True,
                     frozen=False),
                 with_cp=False):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.with_cp = with_cp
        self.is_frozen = [i <= frozen_stages for i in range(num_stages + 1)]
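        # 'eval' is popped here and later used in train() to keep norm layers in
        # eval mode; the remaining dict (type, frozen, ...) is the one passed on
        # to build_norm_layer for every block.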
        assert (isinstance(normalize, dict) and 'eval' in normalize
                and 'frozen' in normalize)
        self.norm_eval = normalize.pop('eval')
        self.normalize = normalize
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp,
                normalize=normalize,
                frozen=self.is_frozen[i + 1])
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    def _make_stem_layer(self):
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        stem_norm = build_norm_layer(self.normalize, 64)
        self.norm_name = 'gn1' if self.normalize['type'] == 'GN' else 'bn1'
        self.add_module(self.norm_name, stem_norm)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        if self.is_frozen[0]:
            for layer in [self.conv1, stem_norm]:
                for param in layer.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    constant_init(m, 1)

            # zero-init the last norm layer of each residual block
            # (https://arxiv.org/abs/1706.02677)
            for m in self.modules():
                if isinstance(m, (Bottleneck, BasicBlock)):
                    last_norm = getattr(m, m.norm_names[-1])
                    constant_init(last_norm, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        x = self.conv1(x)
        x = getattr(self, self.norm_name)(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        super(ResNet, self).train(mode)
        if mode and self.norm_eval:
            for mod in self.modules():
                # trick: eval() only has an effect on BatchNorm layers
                if isinstance(mod, nn.BatchNorm2d):
                    mod.eval()
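

# Usage sketch (illustrative, not part of the original module; assumes the
# package context required by the relative import above):
#
#     import torch
#     model = ResNet(depth=50, num_stages=4, out_indices=(0, 1, 2, 3))
#     model.init_weights(pretrained=None)
#     model.eval()
#     feats = model(torch.rand(1, 3, 224, 224))
#     # feats is a tuple of 4 feature maps with 256/512/1024/2048 channels and
#     # strides 4/8/16/32 relative to the input (ResNet-50).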