"vscode:/vscode.git/clone" did not exist on "eb09070095fc81e41c20687e5ba6552738034c51"
densenet.py 14.7 KB
Newer Older
import re
from collections import OrderedDict
from functools import partial
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = [
    "DenseNet",
    "DenseNet121_Weights",
    "DenseNet161_Weights",
    "DenseNet169_Weights",
    "DenseNet201_Weights",
    "densenet121",
    "densenet161",
    "densenet169",
    "densenet201",
]


class _DenseLayer(nn.Module):
    def __init__(
        self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
    ) -> None:
        super().__init__()
        self.norm1: nn.BatchNorm2d
        self.add_module("norm1", nn.BatchNorm2d(num_input_features))
        self.relu1: nn.ReLU
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.conv1: nn.Conv2d
        self.add_module(
            "conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
        )
        self.norm2: nn.BatchNorm2d
        self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
        self.relu2: nn.ReLU
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.conv2: nn.Conv2d
        self.add_module(
            "conv2", nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
        )
        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # todo: rewrite when torchscript supports any
    def any_requires_grad(self, input: List[Tensor]) -> bool:
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
        def closure(*inputs):
            return self.bn_function(inputs)

        return cp.checkpoint(closure, *input)

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: List[Tensor]) -> Tensor:  # noqa: F811
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features
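
# A shape sketch for one _DenseLayer (illustrative only, not part of the API):
# the layer concatenates all preceding feature maps on the channel dim, squeezes
# them to bn_size * growth_rate channels with the 1x1 conv, and always emits
# growth_rate new channels from the 3x3 conv. With memory_efficient=True the
# concat + 1x1 conv are recomputed during the backward pass via checkpointing.
#
#     layer = _DenseLayer(num_input_features=64, growth_rate=32, bn_size=4, drop_rate=0.0)
#     out = layer([torch.randn(1, 64, 56, 56)])  # 1x1 conv -> 128 ch, 3x3 conv -> 32 ch
#     assert out.shape == (1, 32, 56, 56)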


class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module("denselayer%d" % (i + 1), layer)

    def forward(self, init_features: Tensor) -> Tensor:
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)
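
# Each block concatenates its input with every layer's growth_rate-channel output,
# so it emits num_input_features + num_layers * growth_rate channels. A quick
# illustrative check (not a test that ships with the module):
#
#     block = _DenseBlock(num_layers=6, num_input_features=64, bn_size=4, growth_rate=32, drop_rate=0.0)
#     y = block(torch.randn(1, 64, 56, 56))
#     assert y.shape == (1, 64 + 6 * 32, 56, 56)  # 256 channels, spatial size unchanged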

class _Transition(nn.Sequential):
    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__()
        self.add_module("norm", nn.BatchNorm2d(num_input_features))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
        self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))
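
# A transition halves both the channel count (via the 1x1 conv) and the spatial
# resolution (via the stride-2 average pool). Illustrative check:
#
#     trans = _Transition(num_input_features=256, num_output_features=128)
#     z = trans(torch.randn(1, 256, 56, 56))
#     assert z.shape == (1, 128, 28, 28)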


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottleneck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """

    def __init__(
        self,
        growth_rate: int = 32,
        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0,
        num_classes: int = 1000,
        memory_efficient: bool = False,
    ) -> None:

        super().__init__()
        _log_api_usage_once(self)

        # First convolution
        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", nn.BatchNorm2d(num_init_features)),
                    ("relu0", nn.ReLU(inplace=True)),
                    ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module("transition%d" % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out
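
# Channel arithmetic for the default (densenet121) configuration, shown as a
# hedged sanity check of the loop above rather than normative documentation:
#
#     features = 64                                     # num_init_features
#     for i, num_layers in enumerate((6, 12, 24, 16)):  # block_config
#         features += num_layers * 32                   # growth_rate
#         if i != 3:
#             features //= 2                            # each transition halves channels
#     assert features == 1024                           # classifier input width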


def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )

    state_dict = weights.get_state_dict(progress=progress)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)
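
# For example, the pattern above rewrites a legacy checkpoint key such as
# (key shown for illustration):
#
#     'features.denseblock1.denselayer1.norm.1.weight'
#         -> 'features.denseblock1.denselayer1.norm1.weight'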


def _densenet(
    growth_rate: int,
    block_config: Tuple[int, int, int, int],
    num_init_features: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> DenseNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)

    if weights is not None:
        _load_state_dict(model=model, weights=weights, progress=progress)

    return model


_COMMON_META = {
    "task": "image_classification",
    "architecture": "DenseNet",
    "size": (224, 224),
    "min_size": (29, 29),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/pull/116",
}


class DenseNet121_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 7978856,
            "acc@1": 74.434,
            "acc@5": 91.972,
        },
    )
    DEFAULT = IMAGENET1K_V1


class DenseNet161_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 28681000,
            "acc@1": 77.138,
            "acc@5": 93.560,
        },
    )
    DEFAULT = IMAGENET1K_V1


class DenseNet169_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 14149480,
            "acc@1": 75.600,
            "acc@5": 92.806,
        },
    )
    DEFAULT = IMAGENET1K_V1


class DenseNet201_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet201-c1103571.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 20013928,
            "acc@1": 76.896,
            "acc@5": 93.370,
        },
    )
    DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-121 model from
337
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
338
    The required minimum input size of the model is 29x29.
Geoff Pleiss's avatar
Geoff Pleiss committed
339
340

    Args:
341
        weights (DenseNet121_Weights, optional): The pretrained weights for the model
342
        progress (bool): If True, displays a progress bar of the download to stderr
343
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
344
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
Geoff Pleiss's avatar
Geoff Pleiss committed
345
    """
346
    weights = DenseNet121_Weights.verify(weights)
Geoff Pleiss's avatar
Geoff Pleiss committed
347

348
    return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
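
# Typical usage, assuming the standard torchvision Weights API (kept as a comment
# so importing this module stays side-effect free; real weights need a download):
#
#     from torchvision.models import densenet121, DenseNet121_Weights
#     weights = DenseNet121_Weights.IMAGENET1K_V1
#     model = densenet121(weights=weights).eval()
#     preprocess = weights.transforms()  # resize/crop to 224, ImageNet normalization
#     batch = preprocess(torch.randn(3, 256, 256)).unsqueeze(0)
#     logits = model(batch)  # shape (1, 1000), indexed by _IMAGENET_CATEGORIES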


@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
    The required minimum input size of the model is 29x29.

    Args:
        weights (DenseNet161_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """
    weights = DenseNet161_Weights.verify(weights)

    return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
    The required minimum input size of the model is 29x29.

    Args:
        weights (DenseNet169_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """
    weights = DenseNet169_Weights.verify(weights)

    return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
    The required minimum input size of the model is 29x29.

    Args:
        weights (DenseNet201_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """
    weights = DenseNet201_Weights.verify(weights)

    return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)