alexnet.py 4.58 KB
Newer Older
1
2
from functools import partial
from typing import Any, Optional
3

4
import torch
5
import torch.nn as nn
6

7
from ..transforms._presets import ImageClassification
8
from ..utils import _log_api_usage_once
9
from ._api import register_model, Weights, WeightsEnum
10
from ._meta import _IMAGENET_CATEGORIES
11
from ._utils import _ovewrite_named_param, handle_legacy_interface
12
13


14
__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
15
16


Soumith Chintala's avatar
Soumith Chintala committed
17
class AlexNet(nn.Module):
    """AlexNet image classifier (the "One weird trick" variant).

    The network has three parts:

    * a convolutional feature extractor: five ``Conv2d`` layers with ReLU
      activations and three max-pooling stages;
    * an adaptive average pool to a fixed 6x6 grid;
    * a three-layer fully connected classifier with dropout.

    Args:
        num_classes: Number of output units of the final linear layer.
            Default: 1000.
        dropout: Drop probability used by both ``nn.Dropout`` layers in the
            classifier. Default: 0.5.
    """

    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        _log_api_usage_once(self)
        # NOTE: keep the layer order (and therefore the Sequential indices)
        # exactly as-is. Pretrained checkpoints are matched by state_dict key,
        # e.g. "features.0.weight" or "classifier.6.bias".
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Pool to a fixed 6x6 map. This makes the first Linear layer's input
        # size (256 * 6 * 6) independent of the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class scores for a batch of images.

        Args:
            x: Image batch of shape ``(N, 3, H, W)``. H and W must be large
                enough that the feature map is non-empty after the strided
                convolution and pooling layers.

        Returns:
            Unnormalized class scores of shape ``(N, num_classes)``. No
            softmax is applied.
        """
        x = self.features(x)
        x = self.avgpool(x)
        # Keep the batch dimension; collapse channels x 6 x 6 into one vector.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


55
56
57
58
59
60
61
62
63
class AlexNet_Weights(WeightsEnum):
    """Pretrained weights available for the AlexNet model.

    ``DEFAULT`` is an alias of ``IMAGENET1K_V1``.
    """

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 61100840,
            "min_size": (63, 63),
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 56.522,
                    "acc@5": 79.066,
                }
            },
            # NOTE(review): units presumably GFLOPs and MB respectively — confirm.
            "_ops": 0.714,
            "_weight_size": 233.087,
            "_docs": """
                These weights reproduce closely the results of the paper using a simplified training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


80
@register_model()
# Translates the legacy ``pretrained`` argument into ``weights``
# (presumably pretrained=True -> AlexNet_Weights.IMAGENET1K_V1; see ._utils).
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
    """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.

    .. note::
        AlexNet was originally introduced in the `ImageNet Classification with
        Deep Convolutional Neural Networks
        <https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html>`__
        paper. Our implementation is based instead on the "One weird trick"
        paper above.

    Args:
        weights (:class:`~torchvision.models.AlexNet_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.AlexNet_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.alexnet.AlexNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py>`_
            for more details about this class.

    Returns:
        AlexNet: The constructed model. When ``weights`` is given, the
        pretrained parameters are loaded into it.

    .. autoclass:: torchvision.models.AlexNet_Weights
        :members:
    """
    # NOTE(review): verify() presumably validates/normalizes the argument
    # (e.g. accepts an enum member name as a string) — confirm in ._api.
    weights = AlexNet_Weights.verify(weights)

    if weights is not None:
        # The classifier head must match the checkpoint's output size, so
        # num_classes is forced to the number of categories the weights were
        # trained on.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = AlexNet(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
120
121
122
123
124
125
126
127
128
129
130


# Legacy name -> checkpoint-URL mapping. It is an internal implementation
# detail scheduled for removal in v0.15.
from ._utils import _ModelURLs


model_urls = _ModelURLs(dict(alexnet=AlexNet_Weights.IMAGENET1K_V1.url))