alexnet.py 3.87 KB
Newer Older
1
2
from functools import partial
from typing import Any, Optional
3

4
import torch
5
import torch.nn as nn
6

7
from ..transforms._presets import ImageClassification
8
from ..utils import _log_api_usage_once
9
10
11
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
12
13


14
# Public API of this module: the network class, its pretrained-weights enum,
# and the builder function.
__all__ = [
    "AlexNet",
    "AlexNet_Weights",
    "alexnet",
]
15
16


Soumith Chintala's avatar
Soumith Chintala committed
17
class AlexNet(nn.Module):
    """AlexNet: a convolutional feature extractor, a fixed 6x6 adaptive pool, and an MLP head.

    Args:
        num_classes: size of the output produced by the final ``nn.Linear`` layer.
        dropout: drop probability used by both ``nn.Dropout`` layers of the classifier.
    """

    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        _log_api_usage_once(self)

        # NOTE: the positional order of the layers below determines the
        # ``nn.Sequential`` indices, and therefore the state_dict keys
        # (``features.0.weight``, ``classifier.1.weight``, ...). Checkpoints
        # saved against this layout rely on it, so do not reorder or insert layers.
        conv_stack = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        # Pooling to a fixed 6x6 grid makes the flattened size 256 * 6 * 6,
        # whatever the (sufficiently large) input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        flat_features = 256 * 6 * 6
        head = [
            nn.Dropout(p=dropout),
            nn.Linear(flat_features, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x presumably has shape (N, 3, H, W); the first conv expects 3 input channels.
        pooled = self.avgpool(self.features(x))
        # Keep dim 0 (presumably the batch) and flatten everything after it.
        return self.classifier(torch.flatten(pooled, 1))


55
56
57
58
59
60
61
62
63
class AlexNet_Weights(WeightsEnum):
    """Pretrained weights available for :class:`AlexNet`.

    ``DEFAULT`` is an alias for ``IMAGENET1K_V1``.
    """

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
        # Inference preprocessing preset with a 224 crop.
        transforms=partial(ImageClassification, crop_size=224),
        meta=dict(
            num_params=61100840,
            # Matches the 63x63 minimum input size stated in the alexnet() docstring.
            min_size=(63, 63),
            categories=_IMAGENET_CATEGORIES,
            recipe="https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
            # Top-1 / top-5 accuracy of this checkpoint (values presumably in percent).
            metrics={"acc@1": 56.522, "acc@5": 79.066},
        ),
    )
    # Same value as IMAGENET1K_V1, so Enum makes this an alias rather than a new member.
    DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
    """AlexNet model architecture from the `ImageNet Classification with Deep Convolutional Neural Networks
    <https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html>`__ paper.

    The required minimum input size of the model is 63x63.

    Args:
        weights (:class:`~torchvision.models.AlexNet_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.AlexNet_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.alexnet.AlexNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py>`_
            for more details about this class.

    Returns:
        AlexNet: the constructed model, with the requested pretrained weights
        loaded when ``weights`` is not ``None``.

    .. autoclass:: torchvision.models.AlexNet_Weights
        :members:
    """
    # Validate / normalize the ``weights`` argument (exact accepted forms are
    # defined by WeightsEnum.verify — TODO confirm, e.g. whether strings are allowed).
    weights = AlexNet_Weights.verify(weights)

    if weights is not None:
        # The checkpoint's final layer has one output per category, so force
        # num_classes to match it; otherwise load_state_dict would hit a shape mismatch.
        # NOTE(review): _ovewrite_named_param presumably rejects a conflicting
        # user-supplied num_classes — confirm in ._utils.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = AlexNet(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model