# densenet.py
import re
from collections import OrderedDict
from typing import Any, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor

from .._internally_replaced_utils import load_state_dict_from_url
from ..utils import _log_api_usage_once
# Public API of this module (what ``from <module> import *`` exposes).
__all__ = ["DenseNet", "densenet121", "densenet169", "densenet201", "densenet161"]

# Download locations of the ImageNet-pretrained checkpoints, keyed by the
# architecture name that the public constructors pass to ``_densenet`` as ``arch``.
model_urls = {
    "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
    "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
    "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
    "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
}


class _DenseLayer(nn.Module):
    """Single DenseNet-BC layer: a BN-ReLU-Conv1x1 bottleneck followed by BN-ReLU-Conv3x3.

    The layer receives every feature map produced so far in its dense block. It
    concatenates them along dim 1 (the channel dim expected by ``BatchNorm2d``) and
    returns ``growth_rate`` new channels. The 3x3 conv uses stride 1 and padding 1, so
    the spatial size is unchanged.

    Args:
        num_input_features (int): number of channels after concatenating the inputs
        growth_rate (int): number of output channels (``k`` in the paper)
        bn_size (int): the 1x1 bottleneck conv emits ``bn_size * growth_rate`` channels
        drop_rate (float): dropout probability applied to the output; no dropout when 0
        memory_efficient (bool): if True and any input requires grad, run the bottleneck
          under ``torch.utils.checkpoint``. Its activations are then recomputed during
          backward instead of being stored. Not supported under TorchScript.
    """

    def __init__(
        self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
    ) -> None:
        super().__init__()
        # Submodules are registered via add_module. The bare annotations tell static type
        # checkers that these attributes exist. The registered names determine the
        # state_dict keys (e.g. "norm1.weight").
        self.norm1: nn.BatchNorm2d
        self.add_module("norm1", nn.BatchNorm2d(num_input_features))
        self.relu1: nn.ReLU
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.conv1: nn.Conv2d
        self.add_module(
            "conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
        )
        self.norm2: nn.BatchNorm2d
        self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
        self.relu2: nn.ReLU
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.conv2: nn.Conv2d
        self.add_module(
            "conv2", nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
        )
        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        """Concatenate ``inputs`` along dim 1 and apply norm1 -> relu1 -> conv1."""
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # TODO: rewrite with any() when TorchScript supports it.
    def any_requires_grad(self, input: List[Tensor]) -> bool:
        """Return True if at least one tensor in ``input`` requires grad."""
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
        """Run ``bn_function`` on ``input`` under activation checkpointing (eager mode only)."""

        def closure(*inputs):
            # checkpoint forwards the tensors positionally, so they arrive here as a tuple.
            return self.bn_function(inputs)

        # NOTE(review): recent PyTorch releases warn when ``use_reentrant`` is not passed
        # explicitly; consider pinning it once the minimum supported torch version allows.
        return cp.checkpoint(closure, *input)

    # Overload declarations for TorchScript. Their bodies must stay exactly ``pass``,
    # because torch.jit._overload_method rejects any other body.
    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: List[Tensor]) -> Tensor:  # noqa: F811
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        """Compute ``growth_rate`` new channels from a tensor or a list of tensors.

        Raises:
            RuntimeError: if ``memory_efficient`` checkpointing would be needed while scripting.
        """
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        # Checkpointing only pays off when a backward pass will happen.
        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise RuntimeError("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features


class _DenseBlock(nn.ModuleDict):
    """A stack of ``num_layers`` densely connected ``_DenseLayer`` modules.

    Layer ``i`` (0-based) is built for ``num_input_features + i * growth_rate`` input
    channels: the block input plus the output of every earlier layer. ``forward``
    returns all of those feature maps concatenated along dim 1.
    """

    # Module version that nn.Module records in state_dict metadata.
    # NOTE(review): presumably bumped when this block's layout changed; confirm against history.
    _version = 2

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            # "denselayer1", "denselayer2", ... are the names the checkpoint loader expects.
            self.add_module(f"denselayer{i + 1}", layer)

    def forward(self, init_features: Tensor) -> Tensor:
        """Run every layer on the growing list of feature maps and concatenate them all."""
        features = [init_features]
        for layer in self.values():
            # Each layer sees every feature map produced so far (input included).
            features.append(layer(features))
        return torch.cat(features, 1)


class _Transition(nn.Sequential):
    """Transition between dense blocks: BN -> ReLU -> 1x1 conv -> 2x2 average pool.

    Maps ``num_input_features`` channels to ``num_output_features`` channels. It also
    downsamples each spatial dim by 2 (kernel 2, stride 2; odd sizes are floored).
    """

    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__()
        # Registration order is the nn.Sequential execution order. The names determine
        # the state_dict keys ("norm.weight", "conv.weight", ...).
        self.add_module("norm", nn.BatchNorm2d(num_input_features))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
        self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.

    The network consists of:

    - a stem: a 7x7 stride-2 conv on 3-channel input, then BN, ReLU and a 3x3 stride-2 max-pool;
    - one dense block per entry of ``block_config``, with a channel-halving transition
      layer between consecutive blocks;
    - a final BatchNorm;
    - a linear classifier applied to the globally average-pooled features.

    Args:
        growth_rate (int): how many filters to add each layer (`k` in paper)
        block_config (tuple of ints): how many layers in each pooling block
        num_init_features (int): the number of filters to learn in the first convolution layer
        bn_size (int): multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float): dropout rate after each dense layer
        num_classes (int): number of classification classes
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """

    def __init__(
        self,
        growth_rate: int = 32,
        block_config: Tuple[int, ...] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0,
        num_classes: int = 1000,
        memory_efficient: bool = False,
    ) -> None:

        super().__init__()
        _log_api_usage_once(self)

        # First convolution
        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", nn.BatchNorm2d(num_init_features)),
                    ("relu0", nn.ReLU(inplace=True)),
                    ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        # Each denseblock
        num_features = num_init_features
        last_block = len(block_config) - 1
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.features.add_module(f"denseblock{i + 1}", block)
            # Every layer in the block appends growth_rate channels to its input.
            num_features += num_layers * growth_rate
            if i != last_block:
                # Transitions sit only *between* blocks and halve the channel count.
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module(f"transition{i + 1}", trans)
                num_features //= 2

        # Final batch norm
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Return unnormalized class scores of shape ``(N, num_classes)`` for images ``x``."""
        features = self.features(x)
        # self.features ends with norm5 and no activation, so the final ReLU is applied here.
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out


def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
    """Download the checkpoint at ``model_url``, rename legacy keys, and load it into ``model``.

    '.'s are no longer allowed in module names, but previous _DenseLayer
    has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    They are also in the checkpoints in model_urls. The pattern below finds such
    keys so they can be mapped to the current names ('norm1', 'conv2', ...).

    Args:
        model (nn.Module): model whose ``load_state_dict`` receives the renamed weights
        model_url (str): checkpoint location passed to ``load_state_dict_from_url``
        progress (bool): whether to display a download progress bar
    """
    # BatchNorm's ``num_batches_tracked`` buffer is included as well. Without it, a legacy-named
    # checkpoint that contains that buffer would fail strict loading on an unexpected key.
    pattern = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\."
        r"((?:[12])\.(?:weight|bias|running_mean|running_var|num_batches_tracked))$"
    )

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    # Iterate over a snapshot of the keys because the dict is modified inside the loop.
    for key in list(state_dict):
        res = pattern.match(key)
        if res:
            # e.g. "...denselayer1.norm.1.weight" -> "...denselayer1.norm1.weight"
            state_dict[res.group(1) + res.group(2)] = state_dict.pop(key)
    model.load_state_dict(state_dict)


def _densenet(
    arch: str,
    growth_rate: int,
    block_config: Tuple[int, ...],
    num_init_features: int,
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> DenseNet:
    """Build a ``DenseNet`` and, if ``pretrained``, load the weights registered for ``arch``.

    Args:
        arch (str): key into ``model_urls``; only used when ``pretrained`` is True
        growth_rate (int): forwarded to ``DenseNet``
        block_config (tuple of ints): forwarded to ``DenseNet``
        num_init_features (int): forwarded to ``DenseNet``
        pretrained (bool): if True, download and load the checkpoint for ``arch``
        progress (bool): whether to display a download progress bar
        **kwargs: remaining keyword arguments forwarded to ``DenseNet``
    """
    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
    if pretrained:
        _load_state_dict(model, model_urls[arch], progress)
    return model


def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
    The required minimum input size of the model is 29x29.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
        **kwargs: remaining keyword arguments (including ``memory_efficient``) are
          forwarded to ``DenseNet``
    """
    # growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64
    return _densenet("densenet121", 32, (6, 12, 24, 16), 64, pretrained, progress, **kwargs)


def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
    The required minimum input size of the model is 29x29.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
        **kwargs: remaining keyword arguments (including ``memory_efficient``) are
          forwarded to ``DenseNet``
    """
    # growth_rate=48, block_config=(6, 12, 36, 24), num_init_features=96
    return _densenet("densenet161", 48, (6, 12, 36, 24), 96, pretrained, progress, **kwargs)


def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
    The required minimum input size of the model is 29x29.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
        **kwargs: remaining keyword arguments (including ``memory_efficient``) are
          forwarded to ``DenseNet``
    """
    # growth_rate=32, block_config=(6, 12, 32, 32), num_init_features=64
    return _densenet("densenet169", 32, (6, 12, 32, 32), 64, pretrained, progress, **kwargs)


def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
    The required minimum input size of the model is 29x29.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
        **kwargs: remaining keyword arguments (including ``memory_efficient``) are
          forwarded to ``DenseNet``
    """
    # growth_rate=32, block_config=(6, 12, 48, 32), num_init_features=64
    return _densenet("densenet201", 32, (6, 12, 48, 32), 64, pretrained, progress, **kwargs)