import re
from collections import OrderedDict
from functools import partial
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface

# Public API of this module: the model class, the per-variant weight enums,
# and the builder functions.
__all__ = [
    "DenseNet",
    "DenseNet121_Weights",
    "DenseNet161_Weights",
    "DenseNet169_Weights",
    "DenseNet201_Weights",
    "densenet121",
    "densenet161",
    "densenet169",
    "densenet201",
]


class _DenseLayer(nn.Module):
32
    def __init__(
limm's avatar
limm committed
33
        self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
34
    ) -> None:
limm's avatar
limm committed
35
36
37
38
39
40
41
42
43
        super().__init__()
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)

        self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)

eellison's avatar
eellison committed
44
        self.drop_rate = float(drop_rate)
45
46
        self.memory_efficient = memory_efficient

47
    def bn_function(self, inputs: List[Tensor]) -> Tensor:
eellison's avatar
eellison committed
48
49
50
51
52
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # todo: rewrite when torchscript supports any
53
    def any_requires_grad(self, input: List[Tensor]) -> bool:
eellison's avatar
eellison committed
54
55
56
57
58
59
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
60
    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
eellison's avatar
eellison committed
61
        def closure(*inputs):
62
            return self.bn_function(inputs)
eellison's avatar
eellison committed
63

limm's avatar
limm committed
64
        return cp.checkpoint(closure, *input, use_reentrant=False)
eellison's avatar
eellison committed
65
66

    @torch.jit._overload_method  # noqa: F811
limm's avatar
limm committed
67
    def forward(self, input: List[Tensor]) -> Tensor:  # noqa: F811
eellison's avatar
eellison committed
68
69
        pass

70
    @torch.jit._overload_method  # noqa: F811
limm's avatar
limm committed
71
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
eellison's avatar
eellison committed
72
73
74
75
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
76
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
eellison's avatar
eellison committed
77
78
        if isinstance(input, Tensor):
            prev_features = [input]
79
        else:
eellison's avatar
eellison committed
80
81
82
83
84
85
86
87
88
89
            prev_features = input

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

90
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
91
        if self.drop_rate > 0:
limm's avatar
limm committed
92
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
93
        return new_features


class _DenseBlock(nn.ModuleDict):
    """A stack of ``_DenseLayer`` modules with dense connectivity.

    Layer ``i`` (0-based) receives ``num_input_features + i * growth_rate``
    input channels: the block input plus the outputs of all earlier layers.
    The layers are registered as ``denselayer1`` ... ``denselayer<num_layers>``.

    Args:
        num_layers: number of dense layers in the block.
        num_input_features: channel count of the block input.
        bn_size: bottleneck multiplier forwarded to every layer.
        growth_rate: channels added by every layer.
        drop_rate: dropout probability forwarded to every layer.
        memory_efficient: forwarded to every layer.
    """

    # NOTE(review): presumably the state-dict version number read by
    # ``nn.Module`` serialization - confirm before changing.
    _version = 2

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module(f"denselayer{i + 1}", layer)

    def forward(self, init_features: Tensor) -> Tensor:
        """Run every layer on all features so far and return their concatenation.

        The result is the block input followed by each layer's output, joined
        along dim 1.
        """
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)


class _Transition(nn.Sequential):
128
    def __init__(self, num_input_features: int, num_output_features: int) -> None:
limm's avatar
limm committed
129
130
131
132
133
        super().__init__()
        self.norm = nn.BatchNorm2d(num_input_features)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.

    Args:
        growth_rate (int): how many filters to add each layer (`k` in paper)
        block_config (tuple of 4 ints): how many layers in each pooling block
        num_init_features (int): the number of filters to learn in the first convolution layer
        bn_size (int): multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float): dropout rate after each dense layer
        num_classes (int): number of classification classes
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """

    def __init__(
        self,
        growth_rate: int = 32,
        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0,
        num_classes: int = 1000,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        # First convolution
        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", nn.BatchNorm2d(num_init_features)),
                    ("relu0", nn.ReLU(inplace=True)),
                    ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.features.add_module(f"denseblock{i + 1}", block)
            # Every layer of the block appended ``growth_rate`` channels.
            num_features = num_features + num_layers * growth_rate
            # A transition halving the channel count follows every block but the last.
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module(f"transition{i + 1}", trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                # Only the bias is reset; the weight keeps nn.Linear's own init.
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Return the classifier output for the input batch ``x``.

        The feature extractor output goes through ReLU, is average-pooled to
        ``1 x 1``, flattened from dim 1 onwards and fed to the linear classifier.
        """
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out


def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
    """Fetch the state dict of ``weights`` and load it into ``model``.

    Legacy checkpoint keys such as ``...denselayer1.norm.1.weight`` are renamed
    to ``...denselayer1.norm1.weight`` before loading.

    Args:
        model: module that receives the parameters.
        weights: weight entry whose ``get_state_dict`` supplies the checkpoint.
        progress: forwarded to ``weights.get_state_dict``.
    """
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )

    state_dict = weights.get_state_dict(progress=progress, check_hash=True)
    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for key in list(state_dict):
        res = pattern.match(key)
        if res:
            # Joining the two groups drops the '.' that separated them.
            state_dict[res.group(1) + res.group(2)] = state_dict.pop(key)
    model.load_state_dict(state_dict)


def _densenet(
    growth_rate: int,
    block_config: Tuple[int, int, int, int],
    num_init_features: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> DenseNet:
    """Build a ``DenseNet`` and optionally load pre-trained weights into it.

    Args:
        growth_rate: forwarded to ``DenseNet``.
        block_config: forwarded to ``DenseNet``.
        num_init_features: forwarded to ``DenseNet``.
        weights: weights to load, or ``None`` to keep the fresh initialization.
        progress: forwarded to the weight loading helper.
        **kwargs: extra keyword arguments for ``DenseNet``.

    Returns:
        The constructed model.
    """
    if weights is not None:
        # The classifier size must match the categories the weights were trained on.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)

    if weights is not None:
        _load_state_dict(model=model, weights=weights, progress=progress)

    return model


# Metadata entries shared by all DenseNet weight enums in this file; each
# ``Weights.meta`` dict unpacks this mapping and adds variant-specific fields.
_COMMON_META = {
    "min_size": (29, 29),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/pull/116",
    "_docs": """These weights are ported from LuaTorch.""",
}


# Pre-trained weights accepted by the ``densenet121`` builder.
class DenseNet121_Weights(WeightsEnum):
    # Checkpoint URL, preprocessing preset (224 crop) and reported ImageNet-1K metrics.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 7978856,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.434,
                    "acc@5": 91.972,
                }
            },
            "_ops": 2.834,
            "_file_size": 30.845,
        },
    )
    # DEFAULT aliases the only entry defined above.
    DEFAULT = IMAGENET1K_V1


# Pre-trained weights accepted by the ``densenet161`` builder.
class DenseNet161_Weights(WeightsEnum):
    # Checkpoint URL, preprocessing preset (224 crop) and reported ImageNet-1K metrics.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 28681000,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.138,
                    "acc@5": 93.560,
                }
            },
            "_ops": 7.728,
            "_file_size": 110.369,
        },
    )
    # DEFAULT aliases the only entry defined above.
    DEFAULT = IMAGENET1K_V1


# Pre-trained weights accepted by the ``densenet169`` builder.
class DenseNet169_Weights(WeightsEnum):
    # Checkpoint URL, preprocessing preset (224 crop) and reported ImageNet-1K metrics.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 14149480,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.600,
                    "acc@5": 92.806,
                }
            },
            "_ops": 3.36,
            "_file_size": 54.708,
        },
    )
    # DEFAULT aliases the only entry defined above.
    DEFAULT = IMAGENET1K_V1


# Pre-trained weights accepted by the ``densenet201`` builder.
class DenseNet201_Weights(WeightsEnum):
    # Checkpoint URL, preprocessing preset (224 crop) and reported ImageNet-1K metrics.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet201-c1103571.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 20013928,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.896,
                    "acc@5": 93.370,
                }
            },
            "_ops": 4.291,
            "_file_size": 77.373,
        },
    )
    # DEFAULT aliases the only entry defined above.
    DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-121 model from
    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.

    Args:
        weights (:class:`~torchvision.models.DenseNet121_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.DenseNet121_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.DenseNet121_Weights
        :members:
    """
    weights = DenseNet121_Weights.verify(weights)

    # Positional arguments: growth_rate, block_config, num_init_features.
    return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-161 model from
    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.

    Args:
        weights (:class:`~torchvision.models.DenseNet161_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.DenseNet161_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.DenseNet161_Weights
        :members:
    """
    weights = DenseNet161_Weights.verify(weights)

    # Positional arguments: growth_rate, block_config, num_init_features.
    return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-169 model from
    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.

    Args:
        weights (:class:`~torchvision.models.DenseNet169_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.DenseNet169_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.DenseNet169_Weights
        :members:
    """
    weights = DenseNet169_Weights.verify(weights)

    # Positional arguments: growth_rate, block_config, num_init_features.
    return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-201 model from
    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.

    Args:
        weights (:class:`~torchvision.models.DenseNet201_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.DenseNet201_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.DenseNet201_Weights
        :members:
    """
    weights = DenseNet201_Weights.verify(weights)

    # Positional arguments: growth_rate, block_config, num_init_features.
    return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)