"docs/source/en/api/pipelines/overview.mdx" did not exist on "b25843e799d42246d7a60808259dc5c28163446f"
alexnet.py 4.33 KB
Newer Older
1
2
from functools import partial
from typing import Any, Optional
3

4
import torch
5
import torch.nn as nn
6

7
from ..transforms._presets import ImageClassification
8
from ..utils import _log_api_usage_once
9
10
11
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
12
13


14
# Names exported by ``from <this module> import *``: the model class, its
# weights enum, and the builder function.
__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
15
16


Soumith Chintala's avatar
Soumith Chintala committed
17
class AlexNet(nn.Module):
    """AlexNet image classifier: a convolutional feature extractor followed by
    adaptive average pooling and a three-layer fully connected head.

    Args:
        num_classes (int): Number of output features of the final linear
            layer. Default is 1000.
        dropout (float): Probability ``p`` given to both ``nn.Dropout`` layers
            in the classifier. Default is 0.5.
    """

    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        _log_api_usage_once(self)
        # NOTE: the order of layers inside each ``nn.Sequential`` determines the
        # parameter names in the state dict (``features.0.weight`` ...), so it
        # must not be changed if existing checkpoints are to keep loading.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Pool to a fixed 6x6 spatial size so the first linear layer always
        # receives 256 * 6 * 6 features.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through features, pooling, flattening and the classifier.

        Args:
            x (torch.Tensor): Input batch; the first convolution takes 3 input
                channels (presumably laid out as ``(N, 3, H, W)`` -- confirm
                against callers).

        Returns:
            torch.Tensor: Output of the classifier head.
        """
        x = self.features(x)
        x = self.avgpool(x)
        # Keep dim 0 and collapse every remaining dim into one for the linear layers.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


55
56
57
58
59
60
61
62
63
class AlexNet_Weights(WeightsEnum):
    # Pretrained-weight entries available for the AlexNet model defined in this
    # module. Each entry bundles the checkpoint URL, a factory for the matching
    # preprocessing transforms, and descriptive metadata.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
        # Preprocessing preset constructed with a 224 crop size.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 61100840,
            "min_size": (63, 63),
            # ``len()`` of this entry is used by the builder to set ``num_classes``.
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
            # NOTE(review): presumably top-1 / top-5 accuracy in percent for
            # this checkpoint -- confirm against the recipe linked above.
            "metrics": {
                "acc@1": 56.522,
                "acc@5": 79.066,
            },
        },
    )
    # Alias for the entry to use when no specific version is requested.
    DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
    """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.

    The required minimum input size of the model is 63x63.

    .. note::
        AlexNet was originally introduced in the `ImageNet Classification with
        Deep Convolutional Neural Networks
        <https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html>`__
        paper. Our implementation is based instead on the "One weird trick"
        paper above.

    Args:
        weights (:class:`~torchvision.models.AlexNet_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.AlexNet_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.alexnet.AlexNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.AlexNet_Weights
        :members:
    """
    weights = AlexNet_Weights.verify(weights)

    if weights is not None:
        # The classifier's output size has to match the number of categories of
        # the checkpoint being loaded.
        # NOTE(review): presumably this helper sets ``num_classes`` in kwargs
        # and rejects a conflicting user-supplied value -- confirm in ._utils.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = AlexNet(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
114
115
116
117
118
119
120
121
122
123
124


# The dictionary below is internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs


# Maps the model name "alexnet" to the download URL of its IMAGENET1K_V1 weights.
model_urls = _ModelURLs({"alexnet": AlexNet_Weights.IMAGENET1K_V1.url})