import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
from .utils import load_state_dict_from_url
from torch import Tensor
from torch.jit.annotations import List


__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']

model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                                           growth_rate, kernel_size=1, stride=1,
                                           bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1,
                                           bias=False))
        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs):
        # type: (List[Tensor]) -> Tensor
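        # Concatenating every preceding feature map implements the dense
        # connectivity pattern: each layer sees the outputs of all earlier layers.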
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # TODO: rewrite when torchscript supports the any() builtin
    def any_requires_grad(self, input):
        # type: (List[Tensor]) -> bool
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input):
        # type: (List[Tensor]) -> Tensor
        def closure(*inputs):
            return self.bn_function(*inputs)

        return cp.checkpoint(closure, input)

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (List[Tensor]) -> (Tensor)
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (Tensor) -> (Tensor)
        pass

    # torchscript does not yet support *args, so we overload the method,
    # allowing it to take either a List[Tensor] or a single Tensor
    def forward(self, input):  # noqa: F811
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        return new_features
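
    # Illustrative sketch (an editor's note, not in the original file): via the
    # overloads above, a layer accepts either a single Tensor or a List[Tensor],
    # and the two call forms agree (shapes below are assumptions):
    #
    #   layer = _DenseLayer(64, growth_rate=32, bn_size=4, drop_rate=0.)
    #   x = torch.randn(1, 64, 56, 56)
    #   assert torch.allclose(layer(x), layer([x]))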


class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)
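
    # Illustrative sketch (an editor's note, not in the original file): each
    # layer contributes growth_rate channels, so a block maps C input channels
    # to C + num_layers * growth_rate output channels:
    #
    #   block = _DenseBlock(num_layers=6, num_input_features=64,
    #                       bn_size=4, growth_rate=32, drop_rate=0.)
    #   out = block(torch.randn(1, 64, 56, 56))
    #   assert out.shape[1] == 64 + 6 * 32  # 256 channels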


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
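

# Illustrative sketch (an editor's note, not in the original file): a
# transition halves the spatial resolution via the 2x2 average pool, while the
# caller conventionally also halves the channel count via num_output_features:
#
#   trans = _Transition(num_input_features=256, num_output_features=128)
#   out = trans(torch.randn(1, 256, 56, 56))
#   assert out.shape == (1, 128, 28, 28)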


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int): how many filters to add at each layer (`k` in the paper)
        block_config (list of 4 ints): how many layers in each pooling block
        num_init_features (int): the number of filters to learn in the first convolution layer
        bn_size (int): multiplicative factor for the bottleneck width
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float): dropout rate after each dense layer
        num_classes (int): number of classification classes
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False):

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out
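
    # Illustrative usage (an editor's sketch, not part of the original API):
    # the defaults correspond to DenseNet-121, so the following works as-is:
    #
    #   model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16),
    #                    num_init_features=64)
    #   logits = model(torch.randn(1, 3, 224, 224))  # logits.shape == (1, 1000)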


def _load_state_dict(model, model_url, progress):
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
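    # For example, the legacy key 'features.denseblock1.denselayer1.norm.1.weight'
    # matches and is rewritten below to 'features.denseblock1.denselayer1.norm1.weight'.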

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)


def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
              **kwargs):
    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
    if pretrained:
        _load_state_dict(model, model_urls[arch], progress)
    return model


def densenet121(pretrained=False, progress=True, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress,
                     **kwargs)


def densenet161(pretrained=False, progress=True, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
                     **kwargs)


def densenet169(pretrained=False, progress=True, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
                     **kwargs)


def densenet201(pretrained=False, progress=True, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
                     **kwargs)
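

# Minimal usage sketch (an editor's illustration, not part of the original
# file; loading pretrained weights assumes network access):
#
#   model = densenet121(pretrained=True)
#   model.eval()
#   with torch.no_grad():
#       probs = model(torch.randn(1, 3, 224, 224)).softmax(dim=1)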