Commit e6196137 authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

Added progress flag to model getters (#875)

* added progress flag to model getters

* flake8

* bug fix

* backward commpability
parent 830df55d
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .utils import load_state_dict_from_url
__all__ = ['AlexNet', 'alexnet']
......@@ -48,14 +48,17 @@ class AlexNet(nn.Module):
return x
def alexnet(pretrained=False, **kwargs):
def alexnet(pretrained=False, progress=True, **kwargs):
r"""AlexNet model architecture from the
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = AlexNet(**kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
state_dict = load_state_dict_from_url(model_urls['alexnet'],
progress=progress)
model.load_state_dict(state_dict)
return model
......@@ -2,12 +2,11 @@ import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from .utils import load_state_dict_from_url
from collections import OrderedDict
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
......@@ -22,17 +21,20 @@ class _DenseLayer(nn.Sequential):
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
self.add_module('relu1', nn.ReLU(inplace=True)),
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
growth_rate, kernel_size=1, stride=1, bias=False)),
growth_rate, kernel_size=1, stride=1,
bias=False)),
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
self.add_module('relu2', nn.ReLU(inplace=True)),
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
kernel_size=3, stride=1, padding=1, bias=False)),
kernel_size=3, stride=1, padding=1,
bias=False)),
self.drop_rate = drop_rate
def forward(self, x):
new_features = super(_DenseLayer, self).forward(x)
if self.drop_rate > 0:
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
new_features = F.dropout(new_features, p=self.drop_rate,
training=self.training)
return torch.cat([x, new_features], 1)
......@@ -40,7 +42,8 @@ class _DenseBlock(nn.Sequential):
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate,
bn_size, drop_rate)
self.add_module('denselayer%d' % (i + 1), layer)
......@@ -75,7 +78,8 @@ class DenseNet(nn.Module):
# First convolution
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
padding=3, bias=False)),
('norm0', nn.BatchNorm2d(num_init_features)),
('relu0', nn.ReLU(inplace=True)),
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
......@@ -85,11 +89,13 @@ class DenseNet(nn.Module):
num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
bn_size=bn_size, growth_rate=growth_rate,
drop_rate=drop_rate)
self.features.add_module('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
trans = _Transition(num_input_features=num_features,
num_output_features=num_features // 2)
self.features.add_module('transition%d' % (i + 1), trans)
num_features = num_features // 2
......@@ -117,14 +123,15 @@ class DenseNet(nn.Module):
return out
def _load_state_dict(model, model_url):
def _load_state_dict(model, model_url, progress):
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_url)
state_dict = load_state_dict_from_url(model_url, progress=progress)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
......@@ -134,57 +141,57 @@ def _load_state_dict(model, model_url):
model.load_state_dict(state_dict)
def densenet121(pretrained=False, **kwargs):
def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
**kwargs):
model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
if pretrained:
_load_state_dict(model, model_urls[arch], progress)
return model
def densenet121(pretrained=False, progress=True, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress,
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet121'])
return model
def densenet169(pretrained=False, **kwargs):
r"""Densenet-169 model from
def densenet161(pretrained=False, progress=True, **kwargs):
r"""Densenet-161 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet169'])
return model
def densenet201(pretrained=False, **kwargs):
r"""Densenet-201 model from
def densenet169(pretrained=False, progress=True, **kwargs):
r"""Densenet-169 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet201'])
return model
def densenet161(pretrained=False, **kwargs):
r"""Densenet-161 model from
def densenet201(pretrained=False, progress=True, **kwargs):
r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet161'])
return model
......@@ -3,7 +3,7 @@ from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo
from .utils import load_state_dict_from_url
__all__ = ['GoogLeNet', 'googlenet']
......@@ -15,12 +15,13 @@ model_urls = {
_GoogLeNetOuputs = namedtuple('GoogLeNetOuputs', ['logits', 'aux_logits2', 'aux_logits1'])
def googlenet(pretrained=False, **kwargs):
def googlenet(pretrained=False, progress=True, **kwargs):
r"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
aux_logits (bool): If True, adds two auxiliary branches that can improve training.
Default: *False* when pretrained is True otherwise *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
......@@ -38,7 +39,9 @@ def googlenet(pretrained=False, **kwargs):
kwargs['aux_logits'] = True
kwargs['init_weights'] = False
model = GoogLeNet(**kwargs)
model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
state_dict = load_state_dict_from_url(model_urls['googlenet'],
progress=progress)
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
del model.aux1, model.aux2
......
......@@ -2,7 +2,7 @@ from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from .utils import load_state_dict_from_url
__all__ = ['Inception3', 'inception_v3']
......@@ -16,7 +16,7 @@ model_urls = {
_InceptionOuputs = namedtuple('InceptionOuputs', ['logits', 'aux_logits'])
def inception_v3(pretrained=False, **kwargs):
def inception_v3(pretrained=False, progress=True, **kwargs):
r"""Inception v3 model architecture from
`"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
......@@ -26,6 +26,7 @@ def inception_v3(pretrained=False, **kwargs):
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
aux_logits (bool): If True, add an auxiliary branch that can improve training.
Default: *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
......@@ -40,7 +41,9 @@ def inception_v3(pretrained=False, **kwargs):
else:
original_aux_logits = True
model = Inception3(**kwargs)
model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google']))
state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
progress=progress)
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
del model.AuxLogits
......
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .utils import load_state_dict_from_url
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
......@@ -204,75 +204,79 @@ class ResNet(nn.Module):
return x
def resnet18(pretrained=False, **kwargs):
def _resnet(arch, inplanes, planes, pretrained, progress, **kwargs):
model = ResNet(inplanes, planes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
def resnet18(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
def resnet34(pretrained=False, **kwargs):
def resnet34(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet50(pretrained=False, **kwargs):
def resnet50(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet101(pretrained=False, **kwargs):
def resnet101(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
**kwargs)
def resnet152(pretrained=False, **kwargs):
def resnet152(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)
def resnext50_32x4d(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs)
# if pretrained:
# model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d']))
return model
def resnext50_32x4d(**kwargs):
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained=False, progress=True, **kwargs)
def resnext101_32x8d(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs)
# if pretrained:
# model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d']))
return model
def resnext101_32x8d(**kwargs):
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained=False, progress=True, **kwargs)
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.utils.model_zoo as model_zoo
from .utils import load_state_dict_from_url
__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
model_urls = {
'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
......@@ -38,13 +36,10 @@ class Fire(nn.Module):
class SqueezeNet(nn.Module):
def __init__(self, version=1.0, num_classes=1000):
def __init__(self, version='1_0', num_classes=1000):
super(SqueezeNet, self).__init__()
if version not in [1.0, 1.1]:
raise ValueError("Unsupported SqueezeNet version {version}:"
"1.0 or 1.1 expected".format(version=version))
self.num_classes = num_classes
if version == 1.0:
if version == '1_0':
self.features = nn.Sequential(
nn.Conv2d(3, 96, kernel_size=7, stride=2),
nn.ReLU(inplace=True),
......@@ -60,7 +55,7 @@ class SqueezeNet(nn.Module):
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(512, 64, 256, 256),
)
else:
elif version == '1_1':
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=2),
nn.ReLU(inplace=True),
......@@ -76,6 +71,13 @@ class SqueezeNet(nn.Module):
Fire(384, 64, 256, 256),
Fire(512, 64, 256, 256),
)
else:
# FIXME: Is this needed? SqueezeNet should only be called from the
# FIXME: squeezenet1_x() functions
# FIXME: This checking is not done for the other models
raise ValueError("Unsupported SqueezeNet version {version}:"
"1_0 or 1_1 expected".format(version=version))
# Final convolution is initialized differently from the rest
final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
self.classifier = nn.Sequential(
......@@ -100,21 +102,29 @@ class SqueezeNet(nn.Module):
return x.view(x.size(0), self.num_classes)
def squeezenet1_0(pretrained=False, **kwargs):
def _squeezenet(version, pretrained, progress, **kwargs):
model = SqueezeNet(version, **kwargs)
if pretrained:
arch = 'squeezenet' + version
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
def squeezenet1_0(pretrained=False, progress=True, **kwargs):
r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = SqueezeNet(version=1.0, **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0']))
return model
return _squeezenet('1_0', pretrained, progress, **kwargs)
def squeezenet1_1(pretrained=False, **kwargs):
def squeezenet1_1(pretrained=False, progress=True, **kwargs):
r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
<https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
......@@ -122,8 +132,6 @@ def squeezenet1_1(pretrained=False, **kwargs):
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = SqueezeNet(version=1.1, **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1']))
return model
return _squeezenet('1_1', pretrained, progress, **kwargs)
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .utils import load_state_dict_from_url
__all__ = [
......@@ -75,7 +75,7 @@ def make_layers(cfg, batch_norm=False):
return nn.Sequential(*layers)
cfg = {
cfgs = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
......@@ -83,113 +83,92 @@ cfg = {
}
def vgg11(pretrained=False, **kwargs):
def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
if pretrained:
kwargs['init_weights'] = False
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
def vgg11(pretrained=False, progress=True, **kwargs):
"""VGG 11-layer model (configuration "A")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
kwargs['init_weights'] = False
model = VGG(make_layers(cfg['A']), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
return model
return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
def vgg11_bn(pretrained=False, **kwargs):
def vgg11_bn(pretrained=False, progress=True, **kwargs):
"""VGG 11-layer model (configuration "A") with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
kwargs['init_weights'] = False
model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
return model
return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
def vgg13(pretrained=False, **kwargs):
def vgg13(pretrained=False, progress=True, **kwargs):
"""VGG 13-layer model (configuration "B")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
kwargs['init_weights'] = False
model = VGG(make_layers(cfg['B']), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
return model
return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
def vgg13_bn(pretrained=False, **kwargs):
def vgg13_bn(pretrained=False, progress=True, **kwargs):
"""VGG 13-layer model (configuration "B") with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
kwargs['init_weights'] = False
model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
return model
return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
def vgg16(pretrained=False, **kwargs):
def vgg16(pretrained=False, progress=True, **kwargs):
"""VGG 16-layer model (configuration "D")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
kwargs['init_weights'] = False
model = VGG(make_layers(cfg['D']), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
return model
return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
def vgg16_bn(pretrained=False, **kwargs):
def vgg16_bn(pretrained=False, progress=True, **kwargs):
"""VGG 16-layer model (configuration "D") with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
kwargs['init_weights'] = False
model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
return model
return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
def vgg19(pretrained=False, **kwargs):
def vgg19(pretrained=False, progress=True, **kwargs):
"""VGG 19-layer model (configuration "E")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
kwargs['init_weights'] = False
model = VGG(make_layers(cfg['E']), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
return model
return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
def vgg19_bn(pretrained=False, **kwargs):
def vgg19_bn(pretrained=False, progress=True, **kwargs):
"""VGG 19-layer model (configuration 'E') with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
kwargs['init_weights'] = False
model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
return model
return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment