# densenet.py
import re
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import load_state_dict_from_url

# Public names exported by a star-import of this module.
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']

# Download locations of the ImageNet-pretrained checkpoints, keyed by the
# architecture name that the factory functions pass to ``_densenet``.
# These checkpoints still use the legacy 'norm.1'-style layer keys, which
# ``_load_state_dict`` renames before loading.
model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


class _DenseLayer(nn.Sequential):
    """Bottleneck layer of a dense block: BN-ReLU-Conv(1x1) then BN-ReLU-Conv(3x3).

    The layer computes ``growth_rate`` new feature maps and concatenates them
    onto its input along dimension 1 (channels), so the output has
    ``num_input_features + growth_rate`` channels.

    Args:
        num_input_features (int): number of channels of the input tensor
        growth_rate (int): number of feature maps this layer adds (`k` in paper)
        bn_size (int): the 1x1 bottleneck convolution outputs
            ``bn_size * growth_rate`` channels
        drop_rate (float): dropout probability applied to the new features;
            dropout is skipped when it is not positive
    """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        # Submodule names must stay 'norm1', 'conv1', ...: they form the
        # state_dict keys that pretrained checkpoints are loaded into.
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1,
                                           bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        # Run the registered BN-ReLU-Conv stack in order.
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        # Dense connectivity: keep the input and append the new maps.
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    """A dense block made of ``num_layers`` chained ``_DenseLayer`` modules.

    Each ``_DenseLayer`` concatenates ``growth_rate`` new channels onto its
    input. Layer ``i`` (0-based) therefore receives
    ``num_input_features + i * growth_rate`` channels. The block outputs
    ``num_input_features + num_layers * growth_rate`` channels.

    Args:
        num_layers (int): number of dense layers in the block
        num_input_features (int): number of channels of the block's input
        bn_size (int): bottleneck width multiplier forwarded to each layer
        growth_rate (int): number of channels added by each layer
        drop_rate (float): dropout probability forwarded to each layer
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate,
                                bn_size, drop_rate)
            # 'denselayer1', 'denselayer2', ... appear in checkpoint keys
            # (see the key-renaming regex in _load_state_dict).
            self.add_module('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
    """Transition between two dense blocks: BN, ReLU, 1x1 conv, 2x2 average pool.

    The 1x1 convolution maps ``num_input_features`` channels to
    ``num_output_features`` channels. The 2x2/stride-2 average pool then
    downsamples height and width by a factor of two.

    Args:
        num_input_features (int): number of channels entering the transition
        num_output_features (int): number of channels leaving the transition
    """

    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        # Registered in this exact order and with these names, which form the
        # state_dict keys ('norm', 'relu', 'conv', 'pool').
        stages = (
            ('norm', nn.BatchNorm2d(num_input_features)),
            ('relu', nn.ReLU(inplace=True)),
            ('conv', nn.Conv2d(num_input_features, num_output_features,
                               kernel_size=1, stride=1, bias=False)),
            ('pool', nn.AvgPool2d(kernel_size=2, stride=2)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int): how many filters to add each layer (`k` in paper)
        block_config (sequence of int): how many layers in each dense block; a
            transition layer that halves the channel count is inserted between
            consecutive blocks (not after the last one)
        num_init_features (int): the number of filters to learn in the first
            convolution layer
        bn_size (int): multiplicative factor for number of bottleneck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float): dropout rate after each dense layer
        num_classes (int): number of classification classes
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):

        super(DenseNet, self).__init__()

        # First convolution: 7x7/2 conv on a 3-channel input, then 3x3/2 max-pool.
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock, tracking the running channel count in num_features.
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate,
                                drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            # Every dense layer appends growth_rate channels.
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                # Compress to half the channels between blocks.
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` must have 3 channels in dimension 1 (the stem is a 3-input-channel
        ``Conv2d``).
        """
        features = self.features(x)
        out = F.relu(features, inplace=True)
        # Global average pool to 1x1, then flatten to (batch, num_features).
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out


def _load_state_dict(model, model_url, progress):
    """Download a pretrained DenseNet checkpoint and load it into ``model``.

    Earlier versions of ``_DenseLayer`` named their submodules 'norm.1',
    'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. Dots are no longer
    allowed in module names, so the current layer uses 'norm1', 'conv1', ...
    The checkpoints in ``model_urls`` still carry the old keys. Such keys are
    renamed here before ``model.load_state_dict`` is called.

    Args:
        model (nn.Module): the DenseNet instance to populate
        model_url (str): location of the checkpoint
        progress (bool): forwarded to ``load_state_dict_from_url``
    """
    # For 'features.denseblock1.denselayer1.norm.1.weight', this captures
    # 'features.denseblock1.denselayer1.norm' and '1.weight'. Joining the
    # two groups drops the dot between the module name and its index.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    # Iterate over a snapshot of the keys because the dict is mutated below.
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict.pop(key)
    model.load_state_dict(state_dict)


def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
              **kwargs):
    """Construct a ``DenseNet`` and, if requested, load pretrained weights.

    Args:
        arch (str): key into ``model_urls`` selecting the checkpoint
        growth_rate (int): forwarded to ``DenseNet``
        block_config (sequence of int): forwarded to ``DenseNet``
        num_init_features (int): forwarded to ``DenseNet``
        pretrained (bool): when true, weights from ``model_urls[arch]`` are loaded
        progress (bool): forwarded to the checkpoint download
        **kwargs: remaining ``DenseNet`` constructor arguments
    """
    model = DenseNet(growth_rate=growth_rate, block_config=block_config,
                     num_init_features=num_init_features, **kwargs)
    if not pretrained:
        return model
    _load_state_dict(model, model_urls[arch], progress)
    return model


def densenet121(pretrained=False, progress=True, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Uses growth rate 32, block configuration (6, 12, 24, 16) and 64 initial features.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: further keyword arguments forwarded to :class:`DenseNet`
            (``bn_size``, ``drop_rate``, ``num_classes``)
    """
    return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress,
                     **kwargs)


def densenet161(pretrained=False, progress=True, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Uses growth rate 48, block configuration (6, 12, 36, 24) and 96 initial features.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: further keyword arguments forwarded to :class:`DenseNet`
            (``bn_size``, ``drop_rate``, ``num_classes``)
    """
    return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
                     **kwargs)


def densenet169(pretrained=False, progress=True, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Uses growth rate 32, block configuration (6, 12, 32, 32) and 64 initial features.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: further keyword arguments forwarded to :class:`DenseNet`
            (``bn_size``, ``drop_rate``, ``num_classes``)
    """
    return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
                     **kwargs)


def densenet201(pretrained=False, progress=True, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Uses growth rate 32, block configuration (6, 12, 48, 32) and 64 initial features.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: further keyword arguments forwarded to :class:`DenseNet`
            (``bn_size``, ``drop_rate``, ``num_classes``)
    """
    return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
                     **kwargs)