squeezenet.py 8.36 KB
Newer Older
1
2
from functools import partial
from typing import Any, Optional
3

4
5
import torch
import torch.nn as nn
6
import torch.nn.init as init
7

8
from ..transforms._presets import ImageClassification, InterpolationMode
9
from ..utils import _log_api_usage_once
10
11
12
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
13
14


15
# Public API exported by ``from ... import *``.
__all__ = [
    "SqueezeNet",
    "SqueezeNet1_0_Weights",
    "SqueezeNet1_1_Weights",
    "squeezenet1_0",
    "squeezenet1_1",
]
16
17
18


class Fire(nn.Module):
    """SqueezeNet "Fire" module.

    A 1x1 "squeeze" convolution reduces the channel count, then two parallel
    "expand" branches (a 1x1 conv and a 3x3 conv with ``padding=1``) read the
    squeezed tensor. Their ReLU outputs are concatenated along dim 1. The result
    therefore has ``expand1x1_planes + expand3x3_planes`` channels and the same
    spatial size as the input.

    Args:
        inplanes: number of input channels.
        squeeze_planes: output channels of the 1x1 squeeze convolution.
        expand1x1_planes: output channels of the 1x1 expand branch.
        expand3x3_planes: output channels of the 3x3 expand branch.
    """

    def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None:
        super().__init__()
        self.inplanes = inplanes

        # Squeeze stage. Registration order below fixes the state_dict key order,
        # which pretrained checkpoints rely on, so it must not be shuffled.
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)

        # Expand stage: two branches over the same squeezed features.
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        squeezed = self.squeeze_activation(self.squeeze(x))
        # Both branches consume ``squeezed``. The in-place ReLUs act on each
        # branch's own conv output, so ``squeezed`` itself is never modified.
        branch_1x1 = self.expand1x1_activation(self.expand1x1(squeezed))
        branch_3x3 = self.expand3x3_activation(self.expand3x3(squeezed))
        return torch.cat([branch_1x1, branch_3x3], 1)
34
35
36


class SqueezeNet(nn.Module):
    """SqueezeNet image classifier, in its ``"1_0"`` or ``"1_1"`` variant.

    Args:
        version: architecture variant to build, either ``"1_0"`` or ``"1_1"``.
        num_classes: number of output channels of the final 1x1 convolution,
            i.e. the width of the returned score vector.
        dropout: dropout probability applied in front of the final convolution.

    Raises:
        ValueError: if ``version`` is neither ``"1_0"`` nor ``"1_1"``.

    For a batched ``(N, 3, H, W)`` input, ``forward`` returns an ``(N, num_classes)``
    tensor. The classifier ends in a ``(1, 1)`` adaptive average pool followed by
    ``torch.flatten(..., 1)``.
    """

    def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.num_classes = num_classes

        # The class is exported in ``__all__``, so callers may construct it directly
        # instead of going through squeezenet1_0()/squeezenet1_1(). The version string
        # is therefore validated here.
        if version == "1_0":
            feature_layers = [
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            ]
        elif version == "1_1":
            # Compared with 1.0: a 3x3, 64-channel stem, and max-pools placed after
            # the 2nd and 4th Fire modules instead of after the 3rd and 7th.
            feature_layers = [
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            ]
        else:
            raise ValueError(f"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected")
        # Layer order (and thus the "features.<idx>" state_dict keys) must match the
        # published checkpoints.
        self.features = nn.Sequential(*feature_layers)

        # Both variants end with Fire(512, 64, 256, 256), which emits 256 + 256 = 512
        # channels. A handle to this conv is kept because it is initialised
        # differently from every other convolution (see the loop below).
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            if module is final_conv:
                init.normal_(module.weight, mean=0.0, std=0.01)
            else:
                init.kaiming_uniform_(module.weight)
            if module.bias is not None:
                init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pooled to (N, num_classes, 1, 1) by the classifier; keep only the batch dim.
        return torch.flatten(self.classifier(self.features(x)), 1)
98
99


100
101
102
103
104
105
106
107
108
def _squeezenet(
    version: str,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> SqueezeNet:
    """Construct a :class:`SqueezeNet` and, if requested, load pretrained weights.

    Args:
        version: forwarded to :class:`SqueezeNet` (``"1_0"`` or ``"1_1"``).
        weights: checkpoint to load, or ``None`` for a randomly initialised model.
        progress: forwarded to ``weights.get_state_dict`` when weights are given.
        **kwargs: extra keyword arguments for :class:`SqueezeNet`.

    When ``weights`` is given, ``kwargs["num_classes"]`` is set through
    ``_ovewrite_named_param`` to the number of categories in ``weights.meta``.
    Presumably that helper rejects a conflicting caller-supplied value; confirm in
    ``._utils``. The checkpoint's state dict is then loaded into the new model.
    """
    if weights is None:
        return SqueezeNet(version, **kwargs)

    # The classifier width has to match the checkpoint before the model is built.
    _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    model = SqueezeNet(version, **kwargs)
    model.load_state_dict(weights.get_state_dict(progress=progress))
    return model


117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Metadata shared by every SqueezeNet checkpoint. Each ``*_Weights`` enum below
# unpacks it into its own ``meta`` dict and then adds the variant-specific fields.
_COMMON_META = dict(
    task="image_classification",
    architecture="SqueezeNet",
    publication_year=2016,
    size=(224, 224),
    categories=_IMAGENET_CATEGORIES,
    interpolation=InterpolationMode.BILINEAR,
    recipe="https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
)


class SqueezeNet1_0_Weights(WeightsEnum):
    """Pretrained checkpoints for :func:`squeezenet1_0`."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
        meta={
            **_COMMON_META,
            # Smallest (H, W) input the 1.0 architecture accepts.
            "min_size": (21, 21),
            "num_params": 1248424,
            # Top-1 / top-5 accuracy (presumably % on the ImageNet-1K validation set).
            "acc@1": 58.092,
            "acc@5": 80.420,
        },
        transforms=partial(ImageClassification, crop_size=224),
    )
    # DEFAULT refers to the same entry as IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1


class SqueezeNet1_1_Weights(WeightsEnum):
    """Pretrained checkpoints for :func:`squeezenet1_1`."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
        meta={
            **_COMMON_META,
            # Smallest (H, W) input the 1.1 architecture accepts.
            "min_size": (17, 17),
            "num_params": 1235496,
            # Top-1 / top-5 accuracy (presumably % on the ImageNet-1K validation set).
            "acc@1": 58.178,
            "acc@5": 80.624,
        },
        transforms=partial(ImageClassification, crop_size=224),
    )
    # DEFAULT refers to the same entry as IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1))
def squeezenet1_0(
    *,
    weights: Optional[SqueezeNet1_0_Weights] = None,
    progress: bool = True,
    **kwargs: Any,
) -> SqueezeNet:
    """SqueezeNet 1.0 model from the `SqueezeNet: AlexNet-level
    accuracy with 50x fewer parameters and <0.5MB model size
    <https://arxiv.org/abs/1602.07360>`_ paper.

    Inputs must be at least 21x21 pixels.

    Args:
        weights (:class:`~torchvision.models.SqueezeNet1_0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.SqueezeNet1_0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.squeezenet.SqueezeNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/squeezenet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.SqueezeNet1_0_Weights
        :members:
    """
    # WeightsEnum.verify checks/normalises the user-supplied value (see its definition).
    checked_weights = SqueezeNet1_0_Weights.verify(weights)
    return _squeezenet("1_0", checked_weights, progress, **kwargs)
186
187


188
189
190
191
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1))
def squeezenet1_1(
    *,
    weights: Optional[SqueezeNet1_1_Weights] = None,
    progress: bool = True,
    **kwargs: Any,
) -> SqueezeNet:
    """SqueezeNet 1.1 model from the `official SqueezeNet repo
    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.

    SqueezeNet 1.1 needs 2.4x less computation and has slightly fewer parameters
    than SqueezeNet 1.0, without sacrificing accuracy.
    Inputs must be at least 17x17 pixels.

    Args:
        weights (:class:`~torchvision.models.SqueezeNet1_1_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.SqueezeNet1_1_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.squeezenet.SqueezeNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/squeezenet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.SqueezeNet1_1_Weights
        :members:
    """
    # WeightsEnum.verify checks/normalises the user-supplied value (see its definition).
    checked_weights = SqueezeNet1_1_Weights.verify(weights)
    return _squeezenet("1_1", checked_weights, progress, **kwargs)