# vgg.py
from typing import Union, List, Dict, Any, cast
import torch
import torch.nn as nn
from .._internally_replaced_utils import load_state_dict_from_url


# Public API of this module: the VGG class and its eight factory functions
# (plain and batch-norm variants of the 11/13/16/19-layer configurations).
__all__ = [
    "VGG",
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19",
    "vgg19_bn",
]


# Download locations of the ImageNet-pretrained checkpoints, keyed by the
# ``arch`` string that each factory function passes to ``_vgg``.
model_urls: Dict[str, str] = {
    "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
    "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
    "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
    "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
    "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
    "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
}


class VGG(nn.Module):
    """VGG network: a convolutional feature extractor followed by an MLP classifier.

    Args:
        features (nn.Module): Convolutional trunk (``_vgg`` builds it with
            ``make_layers``). Its output must have 512 channels, because the
            classifier's first ``nn.Linear`` expects ``512 * 7 * 7`` inputs
            after pooling to 7x7.
        num_classes (int): Number of output logits. Default: 1000.
        init_weights (bool): If True, re-initialize every Conv2d, BatchNorm2d
            and Linear submodule via ``_initialize_weights``. Default: True.
        dropout (float): Dropout probability used twice in the classifier.
            Default: 0.5.
    """

    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()
        self.features = features
        # Pool to a fixed 7x7 map so the classifier input size does not depend
        # on the spatial size of the input image.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(N, num_classes)``.

        Assumes ``x`` is a batched image tensor, presumably ``(N, C, H, W)``
        with ``C`` matching the first layer of ``features`` — TODO confirm.
        """
        x = self.features(x)
        x = self.avgpool(x)
        # Keep the batch dimension; flatten channels and spatial dims together.
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def _initialize_weights(self) -> None:
        """Re-initialize the parameters of all Conv2d, BatchNorm2d and Linear submodules.

        * Conv2d: Kaiming-normal weights (``fan_out``, ReLU gain), zero bias.
        * BatchNorm2d: weight 1, bias 0.
        * Linear: weights drawn from N(0, 0.01^2), zero bias.

        Parameters that are ``None`` (e.g. ``bias=False`` or ``affine=False``)
        are skipped instead of raising.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    """Build the convolutional feature extractor described by ``cfg``.

    Args:
        cfg: Layer specs, in order. The string ``"M"`` adds a 2x2 max-pool with
            stride 2. Any other entry is treated as an int ``v`` and adds a 3x3
            convolution (padding 1) with ``v`` output channels, then (if
            ``batch_norm``) a ``BatchNorm2d(v)``, then an in-place ReLU.
        batch_norm: Insert ``nn.BatchNorm2d`` after every convolution.

    Returns:
        An ``nn.Sequential`` of the layers. The first convolution takes 3 input
        channels; each later convolution takes the previous one's output channels.
    """
    layers: List[nn.Module] = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        out_channels = cast(int, v)
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        in_channels = out_channels
    return nn.Sequential(*layers)


# Layer layouts consumed by ``make_layers``: ints are conv output channels and
# "M" is a 2x2 max-pool. Keys are the configuration letters used by the
# factory functions ("A" -> vgg11, "B" -> vgg13, "D" -> vgg16, "E" -> vgg19).
cfgs: Dict[str, List[Union[str, int]]] = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
    """Construct a ``VGG`` model and optionally load pretrained weights into it.

    Args:
        arch: Key into ``model_urls`` selecting the checkpoint to download.
        cfg: Key into ``cfgs`` selecting the convolutional layout.
        batch_norm: Passed to ``make_layers``; inserts BatchNorm2d after each conv.
        pretrained: If True, download ``model_urls[arch]`` and load it with
            ``model.load_state_dict``.
        progress: Passed to ``load_state_dict_from_url`` to control the
            download progress display.
        **kwargs: Forwarded to the ``VGG`` constructor. When ``pretrained`` is
            True, ``init_weights`` is forced to False.

    Returns:
        The constructed ``VGG`` model.
    """
    if pretrained:
        # The checkpoint overwrites all parameters, so skip the random init.
        kwargs["init_weights"] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model


def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 11-layer model (configuration "A") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to :class:`VGG` (e.g. ``num_classes``, ``dropout``,
            ``init_weights``); ``init_weights`` is forced to False when
            ``pretrained`` is True.
    """
    return _vgg("vgg11", "A", False, pretrained, progress, **kwargs)


def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 11-layer model (configuration "A") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to :class:`VGG` (e.g. ``num_classes``, ``dropout``,
            ``init_weights``); ``init_weights`` is forced to False when
            ``pretrained`` is True.
    """
    return _vgg("vgg11_bn", "A", True, pretrained, progress, **kwargs)


def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 13-layer model (configuration "B")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to :class:`VGG` (e.g. ``num_classes``, ``dropout``,
            ``init_weights``); ``init_weights`` is forced to False when
            ``pretrained`` is True.
    """
    return _vgg("vgg13", "B", False, pretrained, progress, **kwargs)


def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 13-layer model (configuration "B") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to :class:`VGG` (e.g. ``num_classes``, ``dropout``,
            ``init_weights``); ``init_weights`` is forced to False when
            ``pretrained`` is True.
    """
    return _vgg("vgg13_bn", "B", True, pretrained, progress, **kwargs)


def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 16-layer model (configuration "D")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to :class:`VGG` (e.g. ``num_classes``, ``dropout``,
            ``init_weights``); ``init_weights`` is forced to False when
            ``pretrained`` is True.
    """
    return _vgg("vgg16", "D", False, pretrained, progress, **kwargs)


def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 16-layer model (configuration "D") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to :class:`VGG` (e.g. ``num_classes``, ``dropout``,
            ``init_weights``); ``init_weights`` is forced to False when
            ``pretrained`` is True.
    """
    return _vgg("vgg16_bn", "D", True, pretrained, progress, **kwargs)


def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 19-layer model (configuration "E")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to :class:`VGG` (e.g. ``num_classes``, ``dropout``,
            ``init_weights``); ``init_weights`` is forced to False when
            ``pretrained`` is True.
    """
    return _vgg("vgg19", "E", False, pretrained, progress, **kwargs)


def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 19-layer model (configuration "E") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to :class:`VGG` (e.g. ``num_classes``, ``dropout``,
            ``init_weights``); ``init_weights`` is forced to False when
            ``pretrained`` is True.
    """
    return _vgg("vgg19_bn", "E", True, pretrained, progress, **kwargs)