densenet.py 12.3 KB
Newer Older
1
import re
Geoff Pleiss's avatar
Geoff Pleiss committed
2
3
4
import torch
import torch.nn as nn
import torch.nn.functional as F
5
import torch.utils.checkpoint as cp
Geoff Pleiss's avatar
Geoff Pleiss committed
6
from collections import OrderedDict
7
from .utils import load_state_dict_from_url
eellison's avatar
eellison committed
8
from torch import Tensor
9
from typing import Any, List, Tuple
10

Geoff Pleiss's avatar
Geoff Pleiss committed
11
12
13
14

# Public names exported by ``from <this module> import *``.
__all__ = [
    'DenseNet',
    'densenet121',
    'densenet169',
    'densenet201',
    'densenet161',
]

# Download location of the pretrained checkpoint for each architecture name;
# ``_densenet`` looks up ``model_urls[arch]`` when ``pretrained`` is True.
model_urls = dict(
    densenet121='https://download.pytorch.org/models/densenet121-a639ec97.pth',
    densenet169='https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    densenet201='https://download.pytorch.org/models/densenet201-c1103571.pth',
    densenet161='https://download.pytorch.org/models/densenet161-8d451a50.pth',
)


eellison's avatar
eellison committed
22
class _DenseLayer(nn.Module):
    """Bottleneck layer of a dense block.

    Concatenates all incoming feature maps along dim 1, applies
    BN -> ReLU -> 1x1 conv (``bn_size * growth_rate`` channels), then
    BN -> ReLU -> 3x3 conv producing ``growth_rate`` channels, and finally
    optional dropout. When ``memory_efficient`` is True and some input requires
    grad, the concat + 1x1 stage runs through ``torch.utils.checkpoint``
    (eager mode only; under TorchScript this path raises).
    """

    def __init__(
        self,
        num_input_features: int,
        growth_rate: int,
        bn_size: int,
        drop_rate: float,
        memory_efficient: bool = False
    ) -> None:
        super().__init__()
        bottleneck_width = bn_size * growth_rate
        # Attribute assignment registers each submodule; these names are the
        # state_dict keys that _load_state_dict maps legacy 'norm.1'-style keys onto.
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features, bottleneck_width,
                               kernel_size=1, stride=1, bias=False)
        self.norm2 = nn.BatchNorm2d(bottleneck_width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bottleneck_width, growth_rate,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        """Concatenate ``inputs`` along dim 1 and run norm1 -> relu1 -> conv1."""
        stacked = torch.cat(inputs, 1)
        return self.conv1(self.relu1(self.norm1(stacked)))  # noqa: T484

    # todo: rewrite when torchscript supports any
    def any_requires_grad(self, input: List[Tensor]) -> bool:
        """Return True as soon as one tensor in ``input`` has ``requires_grad`` set."""
        for t in input:
            if t.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
        """Run :meth:`bn_function` on ``input`` via ``cp.checkpoint``; excluded from TorchScript."""
        def _bottleneck(*tensors):
            return self.bn_function(tensors)

        return cp.checkpoint(_bottleneck, *input)

    # Overload declarations consumed by TorchScript; bodies must stay a bare `pass`.
    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: List[Tensor]) -> Tensor:
        pass

    @torch.jit._overload_method  # type: ignore[no-redef] # noqa: F811
    def forward(self, input: Tensor) -> Tensor:
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
    def forward(self, input: Tensor) -> Tensor:  # type: ignore[no-redef] # noqa: F811
        """Produce ``growth_rate`` new channels from a tensor or a list of tensors."""
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        use_checkpoint = self.memory_efficient and self.any_requires_grad(prev_features)
        if use_checkpoint:
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")
            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        out = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate, training=self.training)
        return out
99
100


eellison's avatar
eellison committed
101
class _DenseBlock(nn.ModuleDict):
    """A stack of ``num_layers`` densely connected ``_DenseLayer`` modules.

    Layer ``i`` (0-based) sees ``num_input_features + i * growth_rate`` input
    channels: the block input plus every earlier layer's output. The block
    returns the concatenation (dim 1) of its input and all layer outputs.
    """

    _version = 2

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False
    ) -> None:
        super().__init__()
        for idx in range(num_layers):
            # Registered as denselayer1, denselayer2, ...
            self.add_module(
                f'denselayer{idx + 1}',
                _DenseLayer(
                    num_input_features + idx * growth_rate,
                    growth_rate=growth_rate,
                    bn_size=bn_size,
                    drop_rate=drop_rate,
                    memory_efficient=memory_efficient,
                ),
            )

    def forward(self, init_features: Tensor) -> Tensor:  # type: ignore[override]
        """Feed the growing feature list through each layer in registration order."""
        feature_list = [init_features]
        for _, dense_layer in self.items():
            feature_list.append(dense_layer(feature_list))
        return torch.cat(feature_list, 1)

131
132

class _Transition(nn.Sequential):
    """Layer placed between dense blocks: BN -> ReLU -> 1x1 conv -> 2x2/2 average pool.

    The 1x1 convolution maps ``num_input_features`` to ``num_output_features``
    channels, and the pooling halves each spatial dimension.
    """

    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__(OrderedDict([
            ('norm', nn.BatchNorm2d(num_input_features)),
            ('relu', nn.ReLU(inplace=True)),
            ('conv', nn.Conv2d(num_input_features, num_output_features,
                               kernel_size=1, stride=1, bias=False)),
            ('pool', nn.AvgPool2d(kernel_size=2, stride=2)),
        ]))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Layout: a 3-channel-input stem (7x7/2 conv, BN, ReLU, 3x3/2 max pool), then
    one dense block per entry of ``block_config`` with a transition layer
    (channel count floor-halved) between consecutive blocks, a final batch norm,
    ReLU, global average pooling and a linear classifier.

    Args:
        growth_rate (int): how many filters to add each layer (`k` in paper)
        block_config (tuple of ints): how many layers in each dense block
        num_init_features (int): the number of filters to learn in the first convolution layer
        bn_size (int): multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float): dropout rate after each dense layer
        num_classes (int): number of classification classes
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """

    def __init__(
        self,
        growth_rate: int = 32,
        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0,
        num_classes: int = 1000,
        memory_efficient: bool = False
    ) -> None:
        super().__init__()

        # Stem: 7x7 stride-2 convolution, BN, ReLU, 3x3 stride-2 max pooling.
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Dense blocks; every block except the last is followed by a transition.
        channels = num_init_features
        last_block = len(block_config) - 1
        for idx, depth in enumerate(block_config):
            self.features.add_module(f'denseblock{idx + 1}', _DenseBlock(
                num_layers=depth,
                num_input_features=channels,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            ))
            channels += depth * growth_rate
            if idx < last_block:
                halved = channels // 2
                self.features.add_module(f'transition{idx + 1}', _Transition(
                    num_input_features=channels,
                    num_output_features=halved,
                ))
                channels = halved

        self.features.add_module('norm5', nn.BatchNorm2d(channels))
        self.classifier = nn.Linear(channels, num_classes)

        # Official init from torch repo: Kaiming-normal conv weights, BN weight 1 and
        # bias 0, zero classifier bias (the classifier weight keeps its default init).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Return classifier outputs of shape ``(batch, num_classes)`` for image batch ``x``."""
        pooled = F.adaptive_avg_pool2d(F.relu(self.features(x), inplace=True), (1, 1))
        return self.classifier(torch.flatten(pooled, 1))


224
def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
    """Download the checkpoint at ``model_url``, rename legacy keys, and load it into ``model``.

    Older ``_DenseLayer`` checkpoints (including those in ``model_urls``) use keys
    such as ``...denselayer1.norm.1.weight``. Module names may no longer contain
    '.', so such keys are rewritten to ``...denselayer1.norm1.weight`` before
    ``model.load_state_dict`` is called.
    """
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for old_key in list(state_dict):
        found = legacy_key.match(old_key)
        if found is None:
            continue
        # Dropping the '.' between the module name and its index, e.g. 'norm.1' -> 'norm1'.
        state_dict[''.join(found.groups())] = state_dict.pop(old_key)
    model.load_state_dict(state_dict)


242
243
244
245
246
247
248
249
250
def _densenet(
    arch: str,
    growth_rate: int,
    block_config: Tuple[int, int, int, int],
    num_init_features: int,
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> DenseNet:
    """Construct a ``DenseNet`` and, if ``pretrained``, load the weights for ``arch``.

    ``arch`` is used only as a key into ``model_urls`` when ``pretrained`` is True.
    Extra keyword arguments are forwarded to the ``DenseNet`` constructor.
    """
    net = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
    if not pretrained:
        return net
    _load_state_dict(net, model_urls[arch], progress)
    return net


257
def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Configuration: growth rate 32, blocks (6, 12, 24, 16), 64 initial features.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet(
        'densenet121',
        growth_rate=32,
        block_config=(6, 12, 24, 16),
        num_init_features=64,
        pretrained=pretrained,
        progress=progress,
        **kwargs
    )
Geoff Pleiss's avatar
Geoff Pleiss committed
269
270


271
def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Configuration: growth rate 48, blocks (6, 12, 36, 24), 96 initial features.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet(
        'densenet161',
        growth_rate=48,
        block_config=(6, 12, 36, 24),
        num_init_features=96,
        pretrained=pretrained,
        progress=progress,
        **kwargs
    )
Geoff Pleiss's avatar
Geoff Pleiss committed
283
284


285
def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Configuration: growth rate 32, blocks (6, 12, 32, 32), 64 initial features.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet(
        'densenet169',
        growth_rate=32,
        block_config=(6, 12, 32, 32),
        num_init_features=64,
        pretrained=pretrained,
        progress=progress,
        **kwargs
    )
Geoff Pleiss's avatar
Geoff Pleiss committed
297
298


299
def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Configuration: growth rate 32, blocks (6, 12, 48, 32), 64 initial features.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    return _densenet(
        'densenet201',
        growth_rate=32,
        block_config=(6, 12, 48, 32),
        num_init_features=64,
        pretrained=pretrained,
        progress=progress,
        **kwargs
    )