from typing import Any

import torch
import torch.nn as nn
import torch.nn.init as init

from .._internally_replaced_utils import load_state_dict_from_url

__all__ = ["SqueezeNet", "squeezenet1_0", "squeezenet1_1"]

model_urls = {
    "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
    "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
}


class Fire(nn.Module):
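    """Fire module from the SqueezeNet paper: a 1x1 "squeeze" convolution
    followed by parallel 1x1 and 3x3 "expand" convolutions whose outputs are
    concatenated along the channel dimension, giving
    ``expand1x1_planes + expand3x3_planes`` output channels at the input's
    spatial resolution.

    Example (illustrative shape check)::

        fire = Fire(96, 16, 64, 64)
        out = fire(torch.randn(1, 96, 54, 54))  # torch.Size([1, 128, 54, 54])
    """
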
    def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None:
        super().__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat(
            [self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], 1
        )


class SqueezeNet(nn.Module):
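    """SqueezeNet v1.0 and v1.1: a convolutional stem followed by a stack of
    Fire modules, with a fully convolutional classifier head (dropout, a 1x1
    convolution to ``num_classes`` channels, ReLU, and global average
    pooling).
    """
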
    def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        self.num_classes = num_classes
        if version == "1_0":
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == "1_1":
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            # FIXME: Is this needed? SqueezeNet should only be called from the
            # FIXME: squeezenet1_x() functions
            # FIXME: This checking is not done for the other models
            raise ValueError(f"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected")

        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1))
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.classifier(x)  # (N, num_classes, 1, 1) after the adaptive average pool
        return torch.flatten(x, 1)  # (N, num_classes)


def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet:
    model = SqueezeNet(version, **kwargs)
    if pretrained:
        arch = "squeezenet" + version
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model


def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
    r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
    accuracy with 50x fewer parameters and <0.5MB model size"
    <https://arxiv.org/abs/1602.07360>`_ paper.
    The required minimum input size of the model is 21x21.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
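
    Example (illustrative shape check; weights are random unless
    ``pretrained=True``)::

        model = squeezenet1_0()
        out = model(torch.randn(1, 3, 224, 224))  # torch.Size([1, 1000])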
    """
    return _squeezenet("1_0", pretrained, progress, **kwargs)


def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
    r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
    SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
    than SqueezeNet 1.0, without sacrificing accuracy.
    The required minimum input size of the model is 17x17.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
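
    Example (illustrative shape check; weights are random unless
    ``pretrained=True``)::

        model = squeezenet1_1()
        out = model(torch.randn(1, 3, 224, 224))  # torch.Size([1, 1000])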
    """
    return _squeezenet("1_1", pretrained, progress, **kwargs)