vgg.py 17.8 KB
Newer Older
1
2
from functools import partial
from typing import Union, List, Dict, Any, Optional, cast
3

4
import torch
5
import torch.nn as nn
6

7
from ..transforms._presets import ImageClassification
8
from ..utils import _log_api_usage_once
9
10
11
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
12
13
14


# Public API of this module: the model class, one weights enum per variant,
# and one builder function per variant.
__all__ = [
    "VGG",
    "VGG11_Weights",
    "VGG11_BN_Weights",
    "VGG13_Weights",
    "VGG13_BN_Weights",
    "VGG16_Weights",
    "VGG16_BN_Weights",
    "VGG19_Weights",
    "VGG19_BN_Weights",
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19",
    "vgg19_bn",
]


Soumith Chintala's avatar
Soumith Chintala committed
35
class VGG(nn.Module):
    """VGG network: a feature extractor followed by a three-layer fully connected classifier.

    Args:
        features (nn.Module): Feature extractor. Its output is adaptively pooled to 7x7,
            flattened and fed to ``nn.Linear(512 * 7 * 7, 4096)``, so it must emit 512 channels.
        num_classes (int): Output size of the last linear layer. Default: 1000.
        init_weights (bool): If True, re-initialise every ``Conv2d``, ``BatchNorm2d`` and
            ``Linear`` sub-module, including those inside ``features``. Default: True.
        dropout (float): Drop probability of the two classifier dropout layers. Default: 0.5.
    """

    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    # ``features`` is caller-supplied: BatchNorm2d(affine=False) has
                    # weight and bias set to None, so guard before initialising.
                    if m.weight is not None:
                        nn.init.constant_(m.weight, 1)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    # Linear(bias=False) has no bias parameter.
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run features, pool to 7x7, flatten all dims after the first, then classify."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


73
74
def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    """Build a convolutional stack from a layer specification.

    Args:
        cfg: Sequence of entries. The string ``"M"`` appends a 2x2, stride-2 max-pool;
            any other entry is treated as the output channel count of a 3x3,
            padding-1 convolution followed by an in-place ReLU.
        batch_norm: If True, insert ``nn.BatchNorm2d`` between each convolution and its ReLU.

    Returns:
        An ``nn.Sequential`` holding the layers in ``cfg`` order. The first convolution
        takes 3 input channels; each later one takes the previous convolution's output count.
    """
    layers: List[nn.Module] = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        v = cast(int, v)
        layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(v))
        layers.append(nn.ReLU(inplace=True))
        in_channels = v
    return nn.Sequential(*layers)


90
# Layer specifications consumed by ``make_layers``: integers are convolution
# output channels, "M" is a max-pool. Keys are selected by the builders below:
# "A" -> vgg11 / vgg11_bn, "B" -> vgg13 / vgg13_bn,
# "D" -> vgg16 / vgg16_bn, "E" -> vgg19 / vgg19_bn.
cfgs: Dict[str, List[Union[str, int]]] = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


98
99
def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
    """Construct a VGG model and optionally load pretrained weights into it.

    Args:
        cfg: Key into ``cfgs`` selecting the layer specification.
        batch_norm: Forwarded to ``make_layers``.
        weights: Weight entry to load, or None to leave the model freshly initialised.
        progress: Forwarded to ``weights.get_state_dict``.
        **kwargs: Forwarded to the ``VGG`` constructor.

    Returns:
        The constructed ``VGG`` instance.
    """
    if weights is not None:
        # The state dict overwrites every parameter, so skip the random init.
        kwargs["init_weights"] = False
        # Entries without a category list (meta["categories"] is None) leave
        # ``num_classes`` to the caller or the VGG default.
        if weights.meta["categories"] is not None:
            _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))
    return model


109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Metadata shared by the VGG weight entries in this module: each entry spreads it
# into its own ``meta`` via ``**_COMMON_META`` and may override individual keys
# (VGG16_Weights.IMAGENET1K_FEATURES replaces "categories" and "recipe").
_COMMON_META = {
    # presumably the smallest supported input (height, width) — TODO confirm
    "min_size": (32, 32),
    # ``_vgg`` derives ``num_classes`` from the length of this list when it is not None.
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
}


class VGG11_Weights(WeightsEnum):
    # Weight entries accepted by the ``vgg11`` builder.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg11-8a719046.pth",
        # Preprocessing preset bound to a crop size of 224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 132863336,
            "metrics": {
                "acc@1": 69.020,
                "acc@5": 88.628,
            },
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1


class VGG11_BN_Weights(WeightsEnum):
    # Weight entries accepted by the ``vgg11_bn`` builder.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
        # Preprocessing preset bound to a crop size of 224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 132868840,
            "metrics": {
                "acc@1": 70.370,
                "acc@5": 89.810,
            },
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1


class VGG13_Weights(WeightsEnum):
    # Weight entries accepted by the ``vgg13`` builder.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg13-19584684.pth",
        # Preprocessing preset bound to a crop size of 224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 133047848,
            "metrics": {
                "acc@1": 69.928,
                "acc@5": 89.246,
            },
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1


class VGG13_BN_Weights(WeightsEnum):
    # Weight entries accepted by the ``vgg13_bn`` builder.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
        # Preprocessing preset bound to a crop size of 224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 133053736,
            "metrics": {
                "acc@1": 71.586,
                "acc@5": 90.374,
            },
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1


class VGG16_Weights(WeightsEnum):
    # Weight entries accepted by the ``vgg16`` builder.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg16-397923af.pth",
        # Preprocessing preset bound to a crop size of 224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 138357544,
            "metrics": {
                "acc@1": 71.592,
                "acc@5": 90.382,
            },
        },
    )
    # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
    # same input standardization method as the paper. Only the `features` weights have proper values, those on the
    # `classifier` module are filled with nans.
    IMAGENET1K_FEATURES = Weights(
        url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
        # Unlike the other entries, this preset overrides ``mean`` and ``std``.
        transforms=partial(
            ImageClassification,
            crop_size=224,
            mean=(0.48235, 0.45882, 0.40784),
            std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
        ),
        meta={
            **_COMMON_META,
            "num_params": 138357544,
            # Overrides the shared metadata: no category list, different recipe.
            "categories": None,
            "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
            # No classification metrics: the classifier weights are nans (see above).
            "metrics": {
                "acc@1": float("nan"),
                "acc@5": float("nan"),
            },
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1


class VGG16_BN_Weights(WeightsEnum):
    # Weight entries accepted by the ``vgg16_bn`` builder.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
        # Preprocessing preset bound to a crop size of 224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 138365992,
            "metrics": {
                "acc@1": 73.360,
                "acc@5": 91.516,
            },
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1


class VGG19_Weights(WeightsEnum):
    # Weight entries accepted by the ``vgg19`` builder.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        # Preprocessing preset bound to a crop size of 224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 143667240,
            "metrics": {
                "acc@1": 72.376,
                "acc@5": 90.876,
            },
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1


class VGG19_BN_Weights(WeightsEnum):
    # Weight entries accepted by the ``vgg19_bn`` builder.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
        # Preprocessing preset bound to a crop size of 224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 143678248,
            "metrics": {
                "acc@1": 74.218,
                "acc@5": 91.842,
            },
        },
    )
    # DEFAULT is an alias of IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG11_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG11_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG11_Weights
        :members:
    """
    weights = VGG11_Weights.verify(weights)

    # Configuration "A", without batch normalisation.
    return _vgg("A", False, weights, progress, **kwargs)
289

290
291
292

@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG11_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG11_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG11_BN_Weights
        :members:
    """
    weights = VGG11_BN_Weights.verify(weights)

    # Configuration "A", with batch normalisation.
    return _vgg("A", True, weights, progress, **kwargs)
314
315


316
317
@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG13_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG13_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG13_Weights
        :members:
    """
    weights = VGG13_Weights.verify(weights)

    # Configuration "B", without batch normalisation.
    return _vgg("B", False, weights, progress, **kwargs)
339

340
341
342

@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG13_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG13_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG13_BN_Weights
        :members:
    """
    weights = VGG13_BN_Weights.verify(weights)

    # Configuration "B", with batch normalisation.
    return _vgg("B", True, weights, progress, **kwargs)
364
365


366
367
@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG16_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG16_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG16_Weights
        :members:
    """
    weights = VGG16_Weights.verify(weights)

    # Configuration "D", without batch normalisation.
    return _vgg("D", False, weights, progress, **kwargs)
389

390
391
392

@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG16_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG16_BN_Weights
        :members:
    """
    weights = VGG16_BN_Weights.verify(weights)

    # Configuration "D", with batch normalisation.
    return _vgg("D", True, weights, progress, **kwargs)
414
415


416
417
@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG19_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG19_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG19_Weights
        :members:
    """
    weights = VGG19_Weights.verify(weights)

    # Configuration "E", without batch normalisation.
    return _vgg("E", False, weights, progress, **kwargs)
439

440
441
442

@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-19-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG19_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG19_BN_Weights
        :members:
    """
    weights = VGG19_BN_Weights.verify(weights)

    # Configuration "E", with batch normalisation.
    return _vgg("E", True, weights, progress, **kwargs)
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481


# The dictionary below is internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs


# Maps each builder name to the download URL of its IMAGENET1K_V1 weights.
# Plain variants are listed first, then the batch-norm variants.
model_urls = _ModelURLs(
    {
        arch: weights_enum.IMAGENET1K_V1.url
        for arch, weights_enum in (
            ("vgg11", VGG11_Weights),
            ("vgg13", VGG13_Weights),
            ("vgg16", VGG16_Weights),
            ("vgg19", VGG19_Weights),
            ("vgg11_bn", VGG11_BN_Weights),
            ("vgg13_bn", VGG13_BN_Weights),
            ("vgg16_bn", VGG16_BN_Weights),
            ("vgg19_bn", VGG19_BN_Weights),
        )
    }
)