Commit d359dfdf authored by Luke Yeager's avatar Luke Yeager Committed by Adam Paszke
Browse files

Expose the num_classes argument when making models

parent df75fa63
...@@ -45,14 +45,14 @@ class AlexNet(nn.Module): ...@@ -45,14 +45,14 @@ class AlexNet(nn.Module):
return x return x
def alexnet(pretrained=False): def alexnet(pretrained=False, **kwargs):
r"""AlexNet model architecture from the r"""AlexNet model architecture from the
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper. `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = AlexNet() model = AlexNet(**kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
return model return model
...@@ -152,61 +152,61 @@ class ResNet(nn.Module): ...@@ -152,61 +152,61 @@ class ResNet(nn.Module):
return x return x
def resnet18(pretrained=False): def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model. """Constructs a ResNet-18 model.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = ResNet(BasicBlock, [2, 2, 2, 2]) model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model return model
def resnet34(pretrained=False): def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model. """Constructs a ResNet-34 model.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = ResNet(BasicBlock, [3, 4, 6, 3]) model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model return model
def resnet50(pretrained=False): def resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model. """Constructs a ResNet-50 model.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = ResNet(Bottleneck, [3, 4, 6, 3]) model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model return model
def resnet101(pretrained=False): def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model. """Constructs a ResNet-101 model.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = ResNet(Bottleneck, [3, 4, 23, 3]) model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model return model
def resnet152(pretrained=False): def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model. """Constructs a ResNet-152 model.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = ResNet(Bottleneck, [3, 8, 36, 3]) model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model return model
...@@ -101,7 +101,7 @@ class SqueezeNet(nn.Module): ...@@ -101,7 +101,7 @@ class SqueezeNet(nn.Module):
return x.view(x.size(0), self.num_classes) return x.view(x.size(0), self.num_classes)
def squeezenet1_0(pretrained=False): def squeezenet1_0(pretrained=False, **kwargs):
r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
accuracy with 50x fewer parameters and <0.5MB model size" accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper. <https://arxiv.org/abs/1602.07360>`_ paper.
...@@ -109,13 +109,13 @@ def squeezenet1_0(pretrained=False): ...@@ -109,13 +109,13 @@ def squeezenet1_0(pretrained=False):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = SqueezeNet(version=1.0) model = SqueezeNet(version=1.0, **kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0'])) model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0']))
return model return model
def squeezenet1_1(pretrained=False): def squeezenet1_1(pretrained=False, **kwargs):
r"""SqueezeNet 1.1 model from the `official SqueezeNet repo r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
<https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_. <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
...@@ -124,7 +124,7 @@ def squeezenet1_1(pretrained=False): ...@@ -124,7 +124,7 @@ def squeezenet1_1(pretrained=False):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = SqueezeNet(version=1.1) model = SqueezeNet(version=1.1, **kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1'])) model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1']))
return model return model
...@@ -78,69 +78,69 @@ cfg = { ...@@ -78,69 +78,69 @@ cfg = {
} }
def vgg11(pretrained=False): def vgg11(pretrained=False, **kwargs):
"""VGG 11-layer model (configuration "A") """VGG 11-layer model (configuration "A")
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = VGG(make_layers(cfg['A'])) model = VGG(make_layers(cfg['A']), **kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
return model return model
def vgg11_bn(): def vgg11_bn(**kwargs):
"""VGG 11-layer model (configuration "A") with batch normalization""" """VGG 11-layer model (configuration "A") with batch normalization"""
return VGG(make_layers(cfg['A'], batch_norm=True)) return VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
def vgg13(pretrained=False): def vgg13(pretrained=False, **kwargs):
"""VGG 13-layer model (configuration "B") """VGG 13-layer model (configuration "B")
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = VGG(make_layers(cfg['B'])) model = VGG(make_layers(cfg['B']), **kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
return model return model
def vgg13_bn(): def vgg13_bn(**kwargs):
"""VGG 13-layer model (configuration "B") with batch normalization""" """VGG 13-layer model (configuration "B") with batch normalization"""
return VGG(make_layers(cfg['B'], batch_norm=True)) return VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
def vgg16(pretrained=False): def vgg16(pretrained=False, **kwargs):
"""VGG 16-layer model (configuration "D") """VGG 16-layer model (configuration "D")
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = VGG(make_layers(cfg['D'])) model = VGG(make_layers(cfg['D']), **kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
return model return model
def vgg16_bn(): def vgg16_bn(**kwargs):
"""VGG 16-layer model (configuration "D") with batch normalization""" """VGG 16-layer model (configuration "D") with batch normalization"""
return VGG(make_layers(cfg['D'], batch_norm=True)) return VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
def vgg19(pretrained=False): def vgg19(pretrained=False, **kwargs):
"""VGG 19-layer model (configuration "E") """VGG 19-layer model (configuration "E")
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = VGG(make_layers(cfg['E'])) model = VGG(make_layers(cfg['E']), **kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
return model return model
def vgg19_bn(): def vgg19_bn(**kwargs):
"""VGG 19-layer model (configuration 'E') with batch normalization""" """VGG 19-layer model (configuration 'E') with batch normalization"""
return VGG(make_layers(cfg['E'], batch_norm=True)) return VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment