densenet.py 14.6 KB
Newer Older
1
import re
2
from collections import OrderedDict
3
4
from functools import partial
from typing import Any, List, Optional, Tuple
5

Geoff Pleiss's avatar
Geoff Pleiss committed
6
7
8
import torch
import torch.nn as nn
import torch.nn.functional as F
9
import torch.utils.checkpoint as cp
eellison's avatar
eellison committed
10
from torch import Tensor
11

12
from ..transforms._presets import ImageClassification
13
from ..utils import _log_api_usage_once
14
15
16
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
17

Geoff Pleiss's avatar
Geoff Pleiss committed
18

19
20
21
22
23
24
25
26
27
28
29
# Public names exported by this module: the model class, one weights enum per
# variant, and one builder function per variant.
__all__ = [
    "DenseNet",
    "DenseNet121_Weights",
    "DenseNet161_Weights",
    "DenseNet169_Weights",
    "DenseNet201_Weights",
    "densenet121",
    "densenet161",
    "densenet169",
    "densenet201",
]
Geoff Pleiss's avatar
Geoff Pleiss committed
30
31


eellison's avatar
eellison committed
32
class _DenseLayer(nn.Module):
    """Single bottleneck layer of a dense block.

    All incoming feature maps are concatenated along dimension 1 and passed
    through BN -> ReLU -> 1x1 conv (the bottleneck), then BN -> ReLU -> 3x3 conv,
    producing ``growth_rate`` new channels.

    Args:
        num_input_features: Channel count of the concatenated input.
        growth_rate: Number of channels this layer produces.
        bn_size: Bottleneck multiplier; the 1x1 conv emits ``bn_size * growth_rate`` channels.
        drop_rate: Dropout probability applied to the output; dropout is skipped when it is not > 0.
        memory_efficient: If True, the bottleneck is run through ``torch.utils.checkpoint``
            whenever at least one input requires grad.
    """

    def __init__(
        self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
    ) -> None:
        super().__init__()
        bottleneck_channels = bn_size * growth_rate

        # Attribute names become state-dict key prefixes, so both the names and
        # the registration order must stay exactly as they are.
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features, bottleneck_channels, kernel_size=1, stride=1, bias=False)

        self.norm2 = nn.BatchNorm2d(bottleneck_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bottleneck_channels, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)

        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        """Concatenate ``inputs`` along dim 1 and apply norm1 -> relu1 -> conv1."""
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # todo: rewrite when torchscript supports any
    def any_requires_grad(self, input: List[Tensor]) -> bool:
        """Return True if at least one tensor in ``input`` requires grad."""
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
        """Run :meth:`bn_function` through ``cp.checkpoint`` (eager mode only)."""

        # checkpoint() forwards positional tensors, so re-pack them for bn_function.
        def closure(*inputs):
            return self.bn_function(inputs)

        return cp.checkpoint(closure, *input)

    # NOTE: the bodies of the two overload stubs below must stay a bare `pass`;
    # TorchScript rejects overload declarations with any other body.
    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: List[Tensor]) -> Tensor:  # noqa: F811
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        """Compute this layer's new features.

        Args:
            input: Either a single tensor or a list of tensors (the outputs of
                all preceding layers), which are concatenated along dim 1.

        Returns:
            The output of conv2, with dropout applied when ``drop_rate > 0``.

        Raises:
            Exception: If the checkpointed path is reached while scripting.
        """
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features
103
104


eellison's avatar
eellison committed
105
class _DenseBlock(nn.ModuleDict):
    """Stack of ``_DenseLayer`` modules with dense connectivity.

    Layer ``i`` (0-based) is built with ``num_input_features + i * growth_rate``
    input channels and is registered as ``denselayer{i + 1}``.

    Args:
        num_layers: Number of dense layers in the block.
        num_input_features: Channel count of the tensor entering the block.
        bn_size: Bottleneck multiplier forwarded to every layer.
        growth_rate: Channels added by every layer.
        drop_rate: Dropout rate forwarded to every layer.
        memory_efficient: Forwarded to every layer.
    """

    # Read by nn.Module's state-dict (de)serialisation as this module's format version.
    _version = 2

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            # 1-based names; they are part of the state-dict keys.
            self.add_module(f"denselayer{i + 1}", layer)

    def forward(self, init_features: Tensor) -> Tensor:
        """Run every layer on the growing feature list and return the concatenation.

        Each layer receives the list of all features produced so far (the block
        input plus every earlier layer's output); the result is that full list
        concatenated along dim 1.
        """
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)

135
136

class _Transition(nn.Sequential):
    """Transition stage: BN -> ReLU -> 1x1 conv -> 2x2 average pool (stride 2).

    Args:
        num_input_features: Channel count of the incoming tensor.
        num_output_features: Channel count produced by the 1x1 convolution.
    """

    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__()
        # Assignment order is execution order in nn.Sequential, and the
        # attribute names are state-dict key prefixes; keep both stable.
        self.norm = nn.BatchNorm2d(num_input_features)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
143
144
145
146


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.

    Args:
        growth_rate (int): how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints): how many layers in each pooling block
        num_init_features (int): the number of filters to learn in the first convolution layer
        bn_size (int): multiplicative factor for number of bottleneck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float): dropout rate after each dense layer
        num_classes (int): number of classification classes
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """

    def __init__(
        self,
        growth_rate: int = 32,
        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0,
        num_classes: int = 1000,
        memory_efficient: bool = False,
    ) -> None:

        super().__init__()
        _log_api_usage_once(self)

        # First convolution
        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", nn.BatchNorm2d(num_init_features)),
                    ("relu0", nn.ReLU(inplace=True)),
                    ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.features.add_module(f"denseblock{i + 1}", block)
            # Every layer in the block appends growth_rate channels.
            num_features = num_features + num_layers * growth_rate
            # A transition that halves the channel count follows every block except the last.
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module(f"transition{i + 1}", trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Return the classifier output for input ``x``.

        Applies the feature extractor, a ReLU, global average pooling to 1x1,
        flattening from dim 1, and the final linear layer.
        """
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out


230
def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
    """Fetch the state dict for ``weights``, fix legacy key names, and load it into ``model``.

    Args:
        model: Module that receives the parameters via ``load_state_dict``.
        weights: Weights entry whose ``get_state_dict`` supplies the checkpoint.
        progress: Forwarded to ``weights.get_state_dict``.
    """
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the published checkpoints. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )

    state_dict = weights.get_state_dict(progress=progress)
    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            # Drop the dot between the module name and its index: "norm.1.weight" -> "norm1.weight".
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict.pop(key)
    model.load_state_dict(state_dict)


249
250
251
252
def _densenet(
    growth_rate: int,
    block_config: Tuple[int, int, int, int],
    num_init_features: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> DenseNet:
    """Build a ``DenseNet`` and optionally load pretrained weights into it.

    Args:
        growth_rate: Forwarded to ``DenseNet``.
        block_config: Forwarded to ``DenseNet``.
        num_init_features: Forwarded to ``DenseNet``.
        weights: Weights entry to load, or ``None`` to keep the freshly initialised parameters.
        progress: Forwarded to the weight loader.
        **kwargs: Extra keyword arguments for ``DenseNet``. When ``weights`` is given,
            ``num_classes`` is set from the length of ``weights.meta["categories"]``
            through ``_ovewrite_named_param``.

    Returns:
        The constructed model.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)

    if weights is not None:
        _load_state_dict(model=model, weights=weights, progress=progress)

    return model


268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
# Metadata shared by every DenseNet weights entry in this file; each entry's
# ``meta`` dict spreads it in with ``**_COMMON_META``.
_COMMON_META = {
    # Same 29x29 minimum input size that the builder docstrings state.
    "min_size": (29, 29),
    "categories": _IMAGENET_CATEGORIES,
    # NOTE(review): presumably the PR describing how the weights were produced - confirm.
    "recipe": "https://github.com/pytorch/vision/pull/116",
}


class DenseNet121_Weights(WeightsEnum):
    """Weights entries accepted by the ``densenet121`` builder."""

    # Checkpoint URL, preprocessing preset (ImageClassification with a 224 crop),
    # and metadata for this entry.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 7978856,
            # NOTE(review): presumably top-1 / top-5 accuracy in percent - confirm against the recipe.
            "acc@1": 74.434,
            "acc@5": 91.972,
        },
    )
    # Alias for the entry above.
    DEFAULT = IMAGENET1K_V1


class DenseNet161_Weights(WeightsEnum):
    """Weights entries accepted by the ``densenet161`` builder."""

    # Checkpoint URL, preprocessing preset (ImageClassification with a 224 crop),
    # and metadata for this entry.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 28681000,
            # NOTE(review): presumably top-1 / top-5 accuracy in percent - confirm against the recipe.
            "acc@1": 77.138,
            "acc@5": 93.560,
        },
    )
    # Alias for the entry above.
    DEFAULT = IMAGENET1K_V1


class DenseNet169_Weights(WeightsEnum):
    """Weights entries accepted by the ``densenet169`` builder."""

    # Checkpoint URL, preprocessing preset (ImageClassification with a 224 crop),
    # and metadata for this entry.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 14149480,
            # NOTE(review): presumably top-1 / top-5 accuracy in percent - confirm against the recipe.
            "acc@1": 75.600,
            "acc@5": 92.806,
        },
    )
    # Alias for the entry above.
    DEFAULT = IMAGENET1K_V1


class DenseNet201_Weights(WeightsEnum):
    """Weights entries accepted by the ``densenet201`` builder."""

    # Checkpoint URL, preprocessing preset (ImageClassification with a 224 crop),
    # and metadata for this entry.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet201-c1103571.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 20013928,
            # NOTE(review): presumably top-1 / top-5 accuracy in percent - confirm against the recipe.
            "acc@1": 76.896,
            "acc@5": 93.370,
        },
    )
    # Alias for the entry above.
    DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
    The required minimum input size of the model is 29x29.

    Args:
        weights (DenseNet121_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``DenseNet`` constructor, e.g.
          ``memory_efficient`` (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """
    weights = DenseNet121_Weights.verify(weights)

    # growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64
    return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
Geoff Pleiss's avatar
Geoff Pleiss committed
346

347
348
349

@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
    The required minimum input size of the model is 29x29.

    Args:
        weights (DenseNet161_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``DenseNet`` constructor, e.g.
          ``memory_efficient`` (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """
    weights = DenseNet161_Weights.verify(weights)

    # growth_rate=48, block_config=(6, 12, 36, 24), num_init_features=96
    return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
Geoff Pleiss's avatar
Geoff Pleiss committed
363
364


365
366
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
    The required minimum input size of the model is 29x29.

    Args:
        weights (DenseNet169_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``DenseNet`` constructor, e.g.
          ``memory_efficient`` (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """
    weights = DenseNet169_Weights.verify(weights)

    # growth_rate=32, block_config=(6, 12, 32, 32), num_init_features=64
    return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
Geoff Pleiss's avatar
Geoff Pleiss committed
380

381
382
383

@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
    The required minimum input size of the model is 29x29.

    Args:
        weights (DenseNet201_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``DenseNet`` constructor, e.g.
          ``memory_efficient`` (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """
    weights = DenseNet201_Weights.verify(weights)

    # growth_rate=32, block_config=(6, 12, 48, 32), num_init_features=64
    return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)