import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
from .utils import load_state_dict_from_url
from torch import Tensor
from torch.jit.annotations import List


__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']

model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
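        # BN-ReLU-Conv(1x1) bottleneck followed by BN-ReLU-Conv(3x3): conv1
        # widens the input to bn_size * growth_rate channels, conv2 narrows
        # it back down to growth_rate new feature maps.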
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                                           growth_rate, kernel_size=1, stride=1,
                                           bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1,
                                           bias=False))
        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs):
        # type: (List[Tensor]) -> Tensor
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # TODO: rewrite when torchscript supports the `any` builtin
    def any_requires_grad(self, input):
        # type: (List[Tensor]) -> bool
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input):
        # type: (List[Tensor]) -> Tensor
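        # Checkpointing trades compute for memory: the activations of the
        # concatenation and 1x1 bottleneck are discarded here and recomputed
        # during the backward pass.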
        def closure(*inputs):
            return self.bn_function(*inputs)

        return cp.checkpoint(closure, input)

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (List[Tensor]) -> (Tensor)
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (Tensor) -> (Tensor)
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
    def forward(self, input):  # noqa: F811
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        return new_features


class _DenseBlock(nn.Module):
    _version = 2
    __constants__ = ['layers']

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
        super(_DenseBlock, self).__init__()
        self.layers = nn.ModuleDict()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.layers['denselayer%d' % (i + 1)] = layer

    def forward(self, init_features):
        features = [init_features]
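        # Each layer consumes the features of all preceding layers and
        # contributes growth_rate new channels, so the block emits
        # num_input_features + num_layers * growth_rate channels in total.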
        for name, layer in self.layers.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)

    @torch.jit.ignore
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        if version is None or version < 2:
            # now we have a new nesting level for torchscript support
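            # e.g. a saved key '<prefix>denselayer1.norm1.weight' becomes
            # '<prefix>layers.denselayer1.norm1.weight' in this module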
            for new_key in self.state_dict().keys():
                # remove prefix "layers."
                old_key = new_key[len("layers."):]
                old_key = prefix + old_key
                new_key = prefix + new_key
                if old_key in state_dict:
                    value = state_dict[old_key]
                    del state_dict[old_key]
                    state_dict[new_key] = value
        super(_DenseBlock, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)


class _Transition(nn.Sequential):
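    # BN-ReLU-Conv(1x1) to change the channel count (DenseNet halves it here),
    # then 2x2 average pooling to halve the spatial resolution.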
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottleneck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"Memory-Efficient Implementation of DenseNets" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """

    __constants__ = ['features']

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False):

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out


def _load_state_dict(model, model_url, progress):
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
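    # e.g. 'features.denseblock1.denselayer1.norm.1.weight' is rewritten to
    # 'features.denseblock1.denselayer1.norm1.weight'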

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)


def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
              **kwargs):
    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
    if pretrained:
        _load_state_dict(model, model_urls[arch], progress)
    return model


def densenet121(pretrained=False, progress=True, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"Memory-Efficient Implementation of DenseNets" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress,
                     **kwargs)


def densenet161(pretrained=False, progress=True, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"Memory-Efficient Implementation of DenseNets" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
                     **kwargs)


def densenet169(pretrained=False, progress=True, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"Memory-Efficient Implementation of DenseNets" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
                     **kwargs)


def densenet201(pretrained=False, progress=True, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"Memory-Efficient Implementation of DenseNets" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
                     **kwargs)
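

# Minimal smoke test, as an illustrative sketch rather than part of the
# original module. Because of the relative import above, it must run in
# package context, e.g. `python -m torchvision.models.densenet`.
if __name__ == '__main__':
    model = densenet121()
    logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # expected: torch.Size([1, 1000])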