"git@developer.sourcefind.cn:OpenDAS/torchani.git" did not exist on "1d64bbae4c6b6f1245df370b3975fab8793e117d"
Commit fd08315f authored by Sam Gross's avatar Sam Gross Committed by Sam Gross
Browse files

Add inline documentation for models

Also add pre-trained ResNet-152 model.

  ResNet-152: Prec@1 78.312 Prec@5 94.046
parent a919deb3
"""The models subpackage contains definitions for the following model
architectures:
- `AlexNet`_
- `VGG`_
- `ResNet`_
You can construct a model with random weights by calling its constructor:
.. code:: python
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
We provide pre-trained models for the ResNet variants and AlexNet, using the
PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing
``pretrained=True``:
.. code:: python
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
.. _AlexNet: https://arxiv.org/abs/1404.5997
.. _VGG: https://arxiv.org/abs/1409.1556
.. _ResNet: https://arxiv.org/abs/1512.03385
"""
from .alexnet import *
from .resnet import *
from .vgg import *
......@@ -46,8 +46,11 @@ class AlexNet(nn.Module):
def alexnet(pretrained=False):
r"""AlexNet model architecture from the "One weird trick" paper.
https://arxiv.org/abs/1404.5997
r"""AlexNet model architecture from the
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = AlexNet()
if pretrained:
......
......@@ -12,6 +12,7 @@ model_urls = {
'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth',
'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth',
}
......@@ -152,6 +153,11 @@ class ResNet(nn.Module):
def resnet18(pretrained=False):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2])
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
......@@ -159,6 +165,11 @@ def resnet18(pretrained=False):
def resnet34(pretrained=False):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3])
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
......@@ -166,6 +177,11 @@ def resnet34(pretrained=False):
def resnet50(pretrained=False):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3])
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
......@@ -173,11 +189,24 @@ def resnet50(pretrained=False):
def resnet101(pretrained=False):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3])
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model
def resnet152():
return ResNet(Bottleneck, [3, 8, 36, 3])
def resnet152(pretrained=False):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3])
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model
......@@ -53,32 +53,40 @@ cfg = {
def vgg11():
"""VGG 11-layer model (configuration "A")"""
return VGG(make_layers(cfg['A']))
def vgg11_bn():
"""VGG 11-layer model (configuration "A") with batch normalization"""
return VGG(make_layers(cfg['A'], batch_norm=True))
def vgg13():
"""VGG 13-layer model (configuration "B")"""
return VGG(make_layers(cfg['B']))
def vgg13_bn():
"""VGG 13-layer model (configuration "B") with batch normalization"""
return VGG(make_layers(cfg['B'], batch_norm=True))
def vgg16():
"""VGG 11-layer model (configuration "B")"""
return VGG(make_layers(cfg['D']))
def vgg16_bn():
"""VGG 16-layer model (configuration "D") with batch normalization"""
return VGG(make_layers(cfg['D'], batch_norm=True))
def vgg19():
"""VGG 19-layer model (configuration "D")"""
return VGG(make_layers(cfg['E']))
def vgg19_bn():
"""VGG 19-layer model (configuration 'E') with batch normalization"""
return VGG(make_layers(cfg['E'], batch_norm=True))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment