import torch
import torch.nn as nn
import torch.nn.init as init
from .utils import load_state_dict_from_url
from typing import Any

__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']

model_urls = {
    'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth',
    'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth',
}


class Fire(nn.Module):
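    """Fire module from the SqueezeNet paper: a 1x1 "squeeze" convolution
    followed by parallel 1x1 and 3x3 "expand" convolutions whose outputs are
    concatenated along the channel dimension. The output therefore has
    expand1x1_planes + expand3x3_planes channels; e.g. Fire(96, 16, 64, 64)
    maps 96 input channels down to 16 and back out to 64 + 64 = 128, while
    preserving spatial size (the 3x3 expand conv uses padding=1).
    """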

    def __init__(
        self,
        inplanes: int,
        squeeze_planes: int,
        expand1x1_planes: int,
        expand3x3_planes: int
    ) -> None:
        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
                                   kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
                                   kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat([
            self.expand1x1_activation(self.expand1x1(x)),
            self.expand3x3_activation(self.expand3x3(x))
        ], 1)


class SqueezeNet(nn.Module):
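    """SqueezeNet v1.0 / v1.1 classifier: a convolutional stem followed by a
    stack of Fire modules with interleaved max pooling, then a fully
    convolutional head (dropout, 1x1 conv to num_classes, global average
    pooling) in place of fully connected layers.
    """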

    def __init__(
        self,
        version: str = '1_0',
        num_classes: int = 1000
    ) -> None:
        super(SqueezeNet, self).__init__()
        self.num_classes = num_classes
        if version == '1_0':
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == '1_1':
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            # FIXME: Is this needed? SqueezeNet should only be called from the
            # FIXME: squeezenet1_x() functions
            # FIXME: This checking is not done for the other models
            raise ValueError("Unsupported SqueezeNet version {version}: "
                             "1_0 or 1_1 expected".format(version=version))

        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
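        # The head stays fully convolutional: the 1x1 conv yields an
        # (N, num_classes, H, W) map, the adaptive average pool reduces it to
        # (N, num_classes, 1, 1), and forward() flattens it to (N, num_classes).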
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.classifier(x)
        return torch.flatten(x, 1)


def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet:
    model = SqueezeNet(version, **kwargs)
    if pretrained:
        arch = 'squeezenet' + version
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
    r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
    accuracy with 50x fewer parameters and <0.5MB model size"
    <https://arxiv.org/abs/1602.07360>`_ paper.
    The required minimum input size of the model is 21x21.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _squeezenet('1_0', pretrained, progress, **kwargs)


def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
    r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
    SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
    than SqueezeNet 1.0, without sacrificing accuracy.
    The required minimum input size of the model is 17x17.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _squeezenet('1_1', pretrained, progress, **kwargs)
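
# Usage sketch (illustrative; this module uses a relative import, so it is
# meant to be imported as part of its package, e.g. via torchvision.models,
# rather than run directly):
#
#     from torchvision import models
#
#     model = models.squeezenet1_1(pretrained=True)  # downloads ImageNet weights
#     model.eval()
#     with torch.no_grad():
#         logits = model(torch.randn(1, 3, 224, 224))  # logits.shape == (1, 1000)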