import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
from .utils import load_state_dict_from_url
from torch import Tensor
from torch.jit.annotations import List


__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']

model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                                           growth_rate, kernel_size=1, stride=1,
                                           bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1,
                                           bias=False))
        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs):
        # type: (List[Tensor]) -> Tensor
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # todo: rewrite when torchscript supports any
    def any_requires_grad(self, input):
        # type: (List[Tensor]) -> bool
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input):
        # type: (List[Tensor]) -> Tensor
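        # Note (added): torch.utils.checkpoint trades compute for memory here.
        # The concatenated input and the norm1/relu1/conv1 activations are not
        # stored during the forward pass; they are recomputed during backward.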
        def closure(*inputs):
            return self.bn_function(*inputs)

        return cp.checkpoint(closure, input)

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (List[Tensor]) -> Tensor
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (Tensor) -> Tensor
        pass

    # torchscript does not yet support *args, so we overload the method,
    # allowing it to take either a List[Tensor] or a single Tensor
    def forward(self, input):  # noqa: F811
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        return new_features

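# Shape sketch for a single _DenseLayer (illustrative values, not from the
# source): with num_input_features=64, growth_rate=32 and bn_size=4, a
# [N, 64, H, W] input becomes a [N, 128, H, W] bottleneck, and the layer
# returns [N, 32, H, W] new features; the enclosing block concatenates these
# onto its running feature list.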

class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)

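# Channel bookkeeping (worked example): a block with num_input_features=64,
# num_layers=6 and growth_rate=32 returns 64 + 6 * 32 = 256 channels, because
# every layer sees all earlier features and contributes growth_rate more.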

class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))

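# With num_output_features = num_input_features // 2 (the compression used in
# DenseNet.__init__ below), a transition maps [N, C, H, W] to
# [N, C // 2, H // 2, W // 2]: the 1x1 conv halves the channel count and the
# stride-2 average pool halves the spatial resolution.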

class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add per layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each dense block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottleneck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory
          efficient, but slower. Default: *False*. See
          `"Memory-Efficient Implementation of DenseNets" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """

    __constants__ = ['features']

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False):

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))
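        # Stem shape sketch (assuming a 224x224 ImageNet-sized input): conv0
        # (7x7, stride 2, padding 3) halves the resolution to 112x112 and
        # pool0 (3x3, stride 2, padding 1) halves it again to 56x56.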

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out

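# Worked channel counts for the default (DenseNet-121) configuration: 64 stem
# features; block 1: 64 + 6 * 32 = 256, halved to 128; block 2:
# 128 + 12 * 32 = 512, halved to 256; block 3: 256 + 24 * 32 = 1024, halved
# to 512; block 4: 512 + 16 * 32 = 1024. The classifier is therefore
# Linear(1024, num_classes).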

def _load_state_dict(model, model_url, progress):
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
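    # e.g. a legacy key 'features.denseblock1.denselayer1.norm.1.weight' is
    # rewritten below to 'features.denseblock1.denselayer1.norm1.weight'.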

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)


def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
              **kwargs):
    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
    if pretrained:
        _load_state_dict(model, model_urls[arch], progress)
    return model


def densenet121(pretrained=False, progress=True, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory
            efficient, but slower. Default: *False*. See
            `"Memory-Efficient Implementation of DenseNets" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress,
                     **kwargs)


def densenet161(pretrained=False, progress=True, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory
            efficient, but slower. Default: *False*. See
            `"Memory-Efficient Implementation of DenseNets" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
                     **kwargs)


def densenet169(pretrained=False, progress=True, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory
            efficient, but slower. Default: *False*. See
            `"Memory-Efficient Implementation of DenseNets" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
                     **kwargs)


def densenet201(pretrained=False, progress=True, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory
            efficient, but slower. Default: *False*. See
            `"Memory-Efficient Implementation of DenseNets" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
                     **kwargs)
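

# Minimal usage sketch (illustrative, not part of this file's API). This
# module uses a relative import, so it cannot be run directly; import it
# through the package instead, e.g.:
#
#     import torch
#     from torchvision.models import densenet121
#
#     model = densenet121(pretrained=False).eval()
#     with torch.no_grad():
#         logits = model(torch.randn(1, 3, 224, 224))
#     print(logits.shape)  # torch.Size([1, 1000])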