Commit fa2836c2 authored by Sam Gross's avatar Sam Gross Committed by Soumith Chintala
Browse files

Add pre-trained VGG models with batch normalization (#178)

Fixes #152
parent 83263d85
...@@ -29,27 +29,31 @@ PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing ...@@ -29,27 +29,31 @@ PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing
ImageNet 1-crop error rates (224x224) ImageNet 1-crop error rates (224x224)
======================== ============= ============= ================================ ============= =============
Network Top-1 error Top-5 error Network Top-1 error Top-5 error
======================== ============= ============= ================================ ============= =============
ResNet-18 30.24 10.92 ResNet-18 30.24 10.92
ResNet-34 26.70 8.58 ResNet-34 26.70 8.58
ResNet-50 23.85 7.13 ResNet-50 23.85 7.13
ResNet-101 22.63 6.44 ResNet-101 22.63 6.44
ResNet-152 21.69 5.94 ResNet-152 21.69 5.94
Inception v3 22.55 6.44 Inception v3 22.55 6.44
AlexNet 43.45 20.91 AlexNet 43.45 20.91
VGG-11 30.98 11.37 VGG-11 30.98 11.37
VGG-13 30.07 10.75 VGG-13 30.07 10.75
VGG-16 28.41 9.62 VGG-16 28.41 9.62
VGG-19 27.62 9.12 VGG-19 27.62 9.12
SqueezeNet 1.0 41.90 19.58 VGG-11 with batch normalization 29.62 10.19
SqueezeNet 1.1 41.81 19.38 VGG-13 with batch normalization 28.45 9.63
Densenet-121 25.35 7.83 VGG-16 with batch normalization 26.63 8.50
Densenet-169 24.00 7.00 VGG-19 with batch normalization 25.76 8.15
Densenet-201 22.80 6.43 SqueezeNet 1.0 41.90 19.58
Densenet-161 22.35 6.20 SqueezeNet 1.1 41.81 19.38
======================== ============= ============= Densenet-121 25.35 7.83
Densenet-169 24.00 7.00
Densenet-201 22.80 6.43
Densenet-161 22.35 6.20
================================ ============= =============
.. _AlexNet: https://arxiv.org/abs/1404.5997 .. _AlexNet: https://arxiv.org/abs/1404.5997
......
...@@ -14,6 +14,10 @@ model_urls = { ...@@ -14,6 +14,10 @@ model_urls = {
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
} }
...@@ -91,9 +95,16 @@ def vgg11(pretrained=False, **kwargs): ...@@ -91,9 +95,16 @@ def vgg11(pretrained=False, **kwargs):
return model return model
def vgg11_bn(**kwargs): def vgg11_bn(pretrained=False, **kwargs):
"""VGG 11-layer model (configuration "A") with batch normalization""" """VGG 11-layer model (configuration "A") with batch normalization
return VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
return model
def vgg13(pretrained=False, **kwargs): def vgg13(pretrained=False, **kwargs):
...@@ -108,9 +119,16 @@ def vgg13(pretrained=False, **kwargs): ...@@ -108,9 +119,16 @@ def vgg13(pretrained=False, **kwargs):
return model return model
def vgg13_bn(**kwargs): def vgg13_bn(pretrained=False, **kwargs):
"""VGG 13-layer model (configuration "B") with batch normalization""" """VGG 13-layer model (configuration "B") with batch normalization
return VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
return model
def vgg16(pretrained=False, **kwargs): def vgg16(pretrained=False, **kwargs):
...@@ -125,9 +143,16 @@ def vgg16(pretrained=False, **kwargs): ...@@ -125,9 +143,16 @@ def vgg16(pretrained=False, **kwargs):
return model return model
def vgg16_bn(**kwargs): def vgg16_bn(pretrained=False, **kwargs):
"""VGG 16-layer model (configuration "D") with batch normalization""" """VGG 16-layer model (configuration "D") with batch normalization
return VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
return model
def vgg19(pretrained=False, **kwargs): def vgg19(pretrained=False, **kwargs):
...@@ -142,6 +167,13 @@ def vgg19(pretrained=False, **kwargs): ...@@ -142,6 +167,13 @@ def vgg19(pretrained=False, **kwargs):
return model return model
def vgg19_bn(**kwargs): def vgg19_bn(pretrained=False, **kwargs):
"""VGG 19-layer model (configuration 'E') with batch normalization""" """VGG 19-layer model (configuration 'E') with batch normalization
return VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment