alexnet.py 4.28 KB
Newer Older
1
2
from functools import partial
from typing import Any, Optional
3

4
import torch
5
import torch.nn as nn
6

7
from ..transforms._presets import ImageClassification
8
from ..utils import _log_api_usage_once
9
10
11
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
12
13


14
__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
15
16


Soumith Chintala's avatar
Soumith Chintala committed
17
class AlexNet(nn.Module):
    """AlexNet classifier: a convolutional feature extractor followed by a
    fully connected head.

    Args:
        num_classes (int): Number of outputs of the final linear layer.
            Default: 1000.
        dropout (float): Probability used by both dropout layers of the
            classifier. Default: 0.5.
    """

    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        _log_api_usage_once(self)

        # NOTE: keep the order of the layers in both ``nn.Sequential``
        # containers unchanged. Parameter names are derived from the position
        # of each layer, so reordering would break loading of existing
        # checkpoints.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Pools the feature map to a fixed 6x6 grid so the first linear layer
        # always receives 256 * 6 * 6 inputs.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feature extractor, pool to 6x6, flatten and classify.

        Args:
            x (torch.Tensor): Input batch; the first convolution expects
                3 input channels.

        Returns:
            torch.Tensor: Output of the classifier, one row per batch element.
        """
        x = self.features(x)
        x = self.avgpool(x)
        # Keep dim 0 (batch) and collapse every remaining dimension.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


55
56
57
58
59
60
61
62
63
class AlexNet_Weights(WeightsEnum):
    # Each member bundles a checkpoint URL, the preprocessing to pair with it
    # and descriptive metadata about the checkpoint.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
        # Preprocessing preset bound to a 224 crop size.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 61100840,
            "min_size": (63, 63),
            # The length of this list is what ``alexnet`` uses as
            # ``num_classes`` when these weights are requested.
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
            "metrics": {
                "acc@1": 56.522,
                "acc@5": 79.066,
            },
        },
    )
    # Alias for the weights to use when no specific version is requested.
    DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
    """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.

    .. note::
        AlexNet was originally introduced in the `ImageNet Classification with
        Deep Convolutional Neural Networks
        <https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html>`__
        paper. Our implementation is based instead on the "One weird trick"
        paper above.

    Args:
        weights (:class:`~torchvision.models.AlexNet_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.AlexNet_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.alexnet.AlexNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.AlexNet_Weights
        :members:
    """

    weights = AlexNet_Weights.verify(weights)

    if weights is not None:
        # The classifier head must match the checkpoint, so ``num_classes`` is
        # taken from the number of categories the weights were trained on.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = AlexNet(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
112
113
114
115
116
117
118
119
120
121
122


# The dictionary below is an internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs


# Maps the model name to the URL of its IMAGENET1K_V1 checkpoint.
model_urls = _ModelURLs(dict(alexnet=AlexNet_Weights.IMAGENET1K_V1.url))