# vgg.py (8.12 KB)
# Newer / Older
1
2
from typing import Union, List, Dict, Any, cast

3
import torch
4
import torch.nn as nn
5

6
from .._internally_replaced_utils import load_state_dict_from_url
7
from ..utils import _log_api_usage_once
8
9
10


# Public API: the VGG class plus one factory per configuration, each plain
# variant listed directly before its batch-normalized ("_bn") counterpart.
__all__ = [
    "VGG",
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19",
    "vgg19_bn",
]


# Download locations of the pretrained checkpoints loaded by ``_vgg``, keyed by
# the architecture name each ``vgg*`` factory passes as ``arch``.
model_urls = dict(
    vgg11="https://download.pytorch.org/models/vgg11-8a719046.pth",
    vgg13="https://download.pytorch.org/models/vgg13-19584684.pth",
    vgg16="https://download.pytorch.org/models/vgg16-397923af.pth",
    vgg19="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
    vgg11_bn="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
    vgg13_bn="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
    vgg16_bn="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
    vgg19_bn="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
)


class VGG(nn.Module):
    """VGG network: a convolutional feature extractor followed by a three-layer MLP classifier.

    Args:
        features: module producing the convolutional feature map. Its output is
            adaptively average-pooled to 7x7 and flattened from dim 1, so the
            flattened size must be ``512 * 7 * 7`` (i.e. 512 channels) to match
            the first classifier layer.
        num_classes: number of outputs of the final linear layer.
        init_weights: if True, re-initialize every Conv2d, BatchNorm2d and Linear
            submodule (including those inside ``features``): Kaiming-normal conv
            weights (fan_out, relu), normal(mean=0, std=0.01) linear weights,
            BatchNorm weights set to 1, and all biases set to 0.
        dropout: dropout probability applied after each of the two hidden layers.
    """

    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.features = features
        # Fixed 7x7 output so the classifier's input width does not depend on the input spatial size.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        hidden = 4096
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden, num_classes),
        )
        if not init_weights:
            return
        # The three module types below are disjoint, so the order of the checks is irrelevant;
        # the traversal order of self.modules() is what fixes the sequence of random draws.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply features -> adaptive average pool -> flatten(dim 1) -> classifier."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))


73
74
def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False, in_channels: int = 3) -> nn.Sequential:
    """Build the convolutional feature extractor described by ``cfg``.

    Args:
        cfg: layer spec read left to right. The string ``"M"`` adds a 2x2 max-pool
            with stride 2; an integer ``v`` adds a 3x3 convolution (padding 1) with
            ``v`` output channels, followed by an in-place ReLU.
        batch_norm: if True, insert a ``BatchNorm2d`` between each convolution and its ReLU.
        in_channels: number of channels of the tensor fed to the first convolution
            (3 by default, the value this function previously hard-coded).

    Returns:
        An ``nn.Sequential`` containing the layers in ``cfg`` order.
    """
    layers: List[nn.Module] = []
    channels = in_channels
    for v in cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        out_channels = cast(int, v)
        layers.append(nn.Conv2d(channels, out_channels, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        # The next convolution consumes this one's output.
        channels = out_channels
    return nn.Sequential(*layers)


90
# Layer specs for configurations A, B, D and E, consumed by ``make_layers``.
# Every configuration has five stages of 3x3 convolutions with output widths
# 64, 128, 256, 512, 512; the tuples give how many convolutions each stage
# has, and every stage is closed by "M" (a 2x2 max-pool in ``make_layers``).
cfgs: Dict[str, List[Union[str, int]]] = {
    name: [
        "M" if position == convs else width
        for width, convs in zip((64, 128, 256, 512, 512), convs_per_stage)
        for position in range(convs + 1)
    ]
    for name, convs_per_stage in (
        ("A", (1, 1, 2, 2, 2)),
        ("B", (2, 2, 2, 2, 2)),
        ("D", (2, 2, 3, 3, 3)),
        ("E", (2, 2, 4, 4, 4)),
    )
}


98
def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
    """Construct a VGG model for configuration ``cfg``, optionally loading pretrained weights.

    Args:
        arch: key into ``model_urls`` selecting the checkpoint to download.
        cfg: key into ``cfgs`` selecting the layer configuration.
        batch_norm: passed to ``make_layers``; inserts BatchNorm2d after each convolution.
        pretrained: if True, download the ``arch`` checkpoint and load it into the model.
        progress: passed to ``load_state_dict_from_url`` to toggle the download progress bar.
        **kwargs: forwarded to the ``VGG`` constructor; ``init_weights`` is forced
            to False when ``pretrained`` is True.

    Returns:
        The constructed ``VGG`` instance.
    """
    if pretrained:
        # The loaded state dict replaces the initial parameters, so skip the custom init.
        kwargs["init_weights"] = False
    features = make_layers(cfgs[cfg], batch_norm=batch_norm)
    model = VGG(features, **kwargs)
    if not pretrained:
        return model
    checkpoint = load_state_dict_from_url(model_urls[arch], progress=progress)
    model.load_state_dict(checkpoint)
    return model


108
def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 11-layer model (configuration "A") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the :class:`VGG` constructor
            (e.g. ``num_classes``, ``dropout``, ``init_weights``); ``init_weights``
            is ignored when ``pretrained`` is True
    """
    return _vgg("vgg11", "A", False, pretrained, progress, **kwargs)
118
119


120
def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 11-layer model (configuration "A") with batch normalization from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the :class:`VGG` constructor
            (e.g. ``num_classes``, ``dropout``, ``init_weights``); ``init_weights``
            is ignored when ``pretrained`` is True
    """
    return _vgg("vgg11_bn", "A", True, pretrained, progress, **kwargs)
130
131


132
def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 13-layer model (configuration "B") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the :class:`VGG` constructor
            (e.g. ``num_classes``, ``dropout``, ``init_weights``); ``init_weights``
            is ignored when ``pretrained`` is True
    """
    return _vgg("vgg13", "B", False, pretrained, progress, **kwargs)
142
143


144
def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 13-layer model (configuration "B") with batch normalization from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the :class:`VGG` constructor
            (e.g. ``num_classes``, ``dropout``, ``init_weights``); ``init_weights``
            is ignored when ``pretrained`` is True
    """
    return _vgg("vgg13_bn", "B", True, pretrained, progress, **kwargs)
154
155


156
def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 16-layer model (configuration "D") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the :class:`VGG` constructor
            (e.g. ``num_classes``, ``dropout``, ``init_weights``); ``init_weights``
            is ignored when ``pretrained`` is True
    """
    return _vgg("vgg16", "D", False, pretrained, progress, **kwargs)
166
167


168
def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 16-layer model (configuration "D") with batch normalization from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the :class:`VGG` constructor
            (e.g. ``num_classes``, ``dropout``, ``init_weights``); ``init_weights``
            is ignored when ``pretrained`` is True
    """
    return _vgg("vgg16_bn", "D", True, pretrained, progress, **kwargs)
178
179


180
def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 19-layer model (configuration "E") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the :class:`VGG` constructor
            (e.g. ``num_classes``, ``dropout``, ``init_weights``); ``init_weights``
            is ignored when ``pretrained`` is True
    """
    return _vgg("vgg19", "E", False, pretrained, progress, **kwargs)
190
191


192
def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    r"""VGG 19-layer model (configuration "E") with batch normalization from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
    The required minimum input size of the model is 32x32.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the :class:`VGG` constructor
            (e.g. ``num_classes``, ``dropout``, ``init_weights``); ``init_weights``
            is ignored when ``pretrained`` is True
    """
    return _vgg("vgg19_bn", "E", True, pretrained, progress, **kwargs)