Commit fa2836c2 authored by Sam Gross's avatar Sam Gross Committed by Soumith Chintala
Browse files

Add pre-trained VGG models with batch normalization (#178)

Fixes #152
parent 83263d85
......@@ -29,9 +29,9 @@ PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing
ImageNet 1-crop error rates (224x224)
======================== ============= =============
================================ ============= =============
Network Top-1 error Top-5 error
======================== ============= =============
================================ ============= =============
ResNet-18 30.24 10.92
ResNet-34 26.70 8.58
ResNet-50 23.85 7.13
......@@ -43,13 +43,17 @@ VGG-11 30.98 11.37
VGG-13 30.07 10.75
VGG-16 28.41 9.62
VGG-19 27.62 9.12
VGG-11 with batch normalization 29.62 10.19
VGG-13 with batch normalization 28.45 9.63
VGG-16 with batch normalization 26.63 8.50
VGG-19 with batch normalization 25.76 8.15
SqueezeNet 1.0 41.90 19.58
SqueezeNet 1.1 41.81 19.38
Densenet-121 25.35 7.83
Densenet-169 24.00 7.00
Densenet-201 22.80 6.43
Densenet-161 22.35 6.20
======================== ============= =============
================================ ============= =============
.. _AlexNet: https://arxiv.org/abs/1404.5997
......
......@@ -14,6 +14,10 @@ model_urls = {
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
......@@ -91,9 +95,16 @@ def vgg11(pretrained=False, **kwargs):
return model
def vgg11_bn(**kwargs):
"""VGG 11-layer model (configuration "A") with batch normalization"""
return VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
def vgg11_bn(pretrained=False, **kwargs):
"""VGG 11-layer model (configuration "A") with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
return model
def vgg13(pretrained=False, **kwargs):
......@@ -108,9 +119,16 @@ def vgg13(pretrained=False, **kwargs):
return model
def vgg13_bn(**kwargs):
"""VGG 13-layer model (configuration "B") with batch normalization"""
return VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
def vgg13_bn(pretrained=False, **kwargs):
"""VGG 13-layer model (configuration "B") with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
return model
def vgg16(pretrained=False, **kwargs):
......@@ -125,9 +143,16 @@ def vgg16(pretrained=False, **kwargs):
return model
def vgg16_bn(**kwargs):
"""VGG 16-layer model (configuration "D") with batch normalization"""
return VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
def vgg16_bn(pretrained=False, **kwargs):
"""VGG 16-layer model (configuration "D") with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
return model
def vgg19(pretrained=False, **kwargs):
......@@ -142,6 +167,13 @@ def vgg19(pretrained=False, **kwargs):
return model
def vgg19_bn(**kwargs):
"""VGG 19-layer model (configuration 'E') with batch normalization"""
return VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
def vgg19_bn(pretrained=False, **kwargs):
"""VGG 19-layer model (configuration 'E') with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment