squeezenet.py 5.53 KB
Newer Older
1
2
from typing import Any

3
4
import torch
import torch.nn as nn
5
import torch.nn.init as init
6

7
from .._internally_replaced_utils import load_state_dict_from_url
8
from ..utils import _log_api_usage_once
9

10
# Public names exported by ``from <this module> import *``.
__all__ = [
    "SqueezeNet",
    "squeezenet1_0",
    "squeezenet1_1",
]
11
12

# Download locations of the ImageNet-pretrained checkpoints, keyed by "squeezenet" + version.
model_urls = dict(
    squeezenet1_0="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
    squeezenet1_1="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
)


class Fire(nn.Module):
    """SqueezeNet "fire" module: a 1x1 squeeze conv feeding two parallel expand convs.

    The squeezed activation is passed through a 1x1 and a 3x3 (padding=1) expand
    convolution; each branch is ReLU-activated and the two are concatenated along
    dim 1, giving ``expand1x1_planes + expand3x3_planes`` output channels.

    Args:
        inplanes: Number of channels of the incoming tensor.
        squeeze_planes: Output channels of the 1x1 squeeze convolution.
        expand1x1_planes: Output channels of the 1x1 expand branch.
        expand3x3_planes: Output channels of the 3x3 expand branch.
    """

    def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None:
        super().__init__()
        self.inplanes = inplanes
        # Submodule attribute names become state_dict keys, so they (and their
        # registration order) must stay fixed for pretrained checkpoints to load.
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        # padding=1 with a 3x3 kernel (stride 1) keeps the spatial size equal to the 1x1 branch.
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        squeezed = self.squeeze_activation(self.squeeze(x))
        branch_1x1 = self.expand1x1_activation(self.expand1x1(squeezed))
        branch_3x3 = self.expand3x3_activation(self.expand3x3(squeezed))
        return torch.cat((branch_1x1, branch_3x3), dim=1)
34
35
36


class SqueezeNet(nn.Module):
    """SqueezeNet image classifier in either the 1.0 or the 1.1 layout.

    The network is a ``features`` trunk (convolutions, max-pools and ``Fire``
    modules, ending at 512 channels) followed by a ``classifier`` head
    (dropout, 1x1 conv to ``num_classes`` channels, ReLU, global average pool).

    Args:
        version: ``"1_0"`` or ``"1_1"``; selects the trunk layout.
        num_classes: Output channels of the final 1x1 convolution, i.e. the
            number of class scores produced per sample.
        dropout: Probability used by the ``nn.Dropout`` at the start of the head.

    Raises:
        ValueError: If ``version`` is neither ``"1_0"`` nor ``"1_1"``.
    """

    def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.num_classes = num_classes
        self.features = self._build_features(version)

        # The head convolution is initialized differently from every other conv,
        # so keep a handle to it for the identity check below.
        head_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            head_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            if conv is head_conv:
                init.normal_(conv.weight, mean=0.0, std=0.01)
            else:
                init.kaiming_uniform_(conv.weight)
            if conv.bias is not None:
                init.constant_(conv.bias, 0)

    @staticmethod
    def _build_features(version: str) -> nn.Sequential:
        """Return the convolutional trunk for ``version``; raise ValueError if it is unknown."""

        def pool() -> nn.MaxPool2d:
            # A fresh instance per call so each pooling stage is its own submodule.
            return nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        if version == "1_0":
            layers = [
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                pool(),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                pool(),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                pool(),
                Fire(512, 64, 256, 256),
            ]
        elif version == "1_1":
            layers = [
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                pool(),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                pool(),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                pool(),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            ]
        else:
            raise ValueError(f"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected")
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run trunk and head, then flatten from dim 1.

        For a batched ``(N, C, H, W)`` input the result has shape ``(N, num_classes)``.
        """
        pooled = self.classifier(self.features(x))
        return torch.flatten(pooled, 1)
98
99


100
def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet:
    """Build a ``SqueezeNet`` of the given version, optionally loading pretrained weights.

    Args:
        version: ``"1_0"`` or ``"1_1"``; also selects the ``model_urls`` entry
            (key ``"squeezenet" + version``).
        pretrained: If True, download that checkpoint and load it into the model.
        progress: Forwarded to ``load_state_dict_from_url`` to toggle the download progress bar.
        **kwargs: Forwarded to the ``SqueezeNet`` constructor.
    """
    net = SqueezeNet(version, **kwargs)
    if not pretrained:
        return net
    weights_url = model_urls["squeezenet" + version]
    net.load_state_dict(load_state_dict_from_url(weights_url, progress=progress))
    return net


109
def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
    r"""Construct SqueezeNet 1.0 as described in the `"SqueezeNet: AlexNet-level
    accuracy with 50x fewer parameters and <0.5MB model size"
    <https://arxiv.org/abs/1602.07360>`_ paper.

    Inputs must be at least 21x21 pixels.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments handed on to the ``SqueezeNet``
            constructor (e.g. ``num_classes``, ``dropout``).
    """
    version = "1_0"
    return _squeezenet(version, pretrained, progress, **kwargs)
120
121


122
def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
    r"""Construct SqueezeNet 1.1 from the `official SqueezeNet repo
    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.

    SqueezeNet 1.1 needs 2.4x less computation and has slightly fewer parameters
    than SqueezeNet 1.0, without sacrificing accuracy.
    Inputs must be at least 17x17 pixels.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments handed on to the ``SqueezeNet``
            constructor (e.g. ``num_classes``, ``dropout``).
    """
    version = "1_1"
    return _squeezenet(version, pretrained, progress, **kwargs)