Commit d359dfdf authored by Luke Yeager's avatar Luke Yeager Committed by Adam Paszke
Browse files

Expose the num_classes argument when making models

parent df75fa63
......@@ -45,14 +45,14 @@ class AlexNet(nn.Module):
return x
def alexnet(pretrained=False):
def alexnet(pretrained=False, **kwargs):
r"""AlexNet model architecture from the
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = AlexNet()
model = AlexNet(**kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
return model
......@@ -152,61 +152,61 @@ class ResNet(nn.Module):
return x
def resnet18(pretrained=False):
def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2])
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
def resnet34(pretrained=False):
def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3])
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model
def resnet50(pretrained=False):
def resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3])
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
def resnet101(pretrained=False):
def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3])
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model
def resnet152(pretrained=False):
def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3])
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model
......@@ -101,7 +101,7 @@ class SqueezeNet(nn.Module):
return x.view(x.size(0), self.num_classes)
def squeezenet1_0(pretrained=False):
def squeezenet1_0(pretrained=False, **kwargs):
r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper.
......@@ -109,13 +109,13 @@ def squeezenet1_0(pretrained=False):
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = SqueezeNet(version=1.0)
model = SqueezeNet(version=1.0, **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0']))
return model
def squeezenet1_1(pretrained=False):
def squeezenet1_1(pretrained=False, **kwargs):
r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
<https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
......@@ -124,7 +124,7 @@ def squeezenet1_1(pretrained=False):
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = SqueezeNet(version=1.1)
model = SqueezeNet(version=1.1, **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1']))
return model
......@@ -78,69 +78,69 @@ cfg = {
}
def vgg11(pretrained=False):
def vgg11(pretrained=False, **kwargs):
"""VGG 11-layer model (configuration "A")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['A']))
model = VGG(make_layers(cfg['A']), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
return model
def vgg11_bn():
def vgg11_bn(**kwargs):
"""VGG 11-layer model (configuration "A") with batch normalization"""
return VGG(make_layers(cfg['A'], batch_norm=True))
return VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
def vgg13(pretrained=False):
def vgg13(pretrained=False, **kwargs):
"""VGG 13-layer model (configuration "B")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['B']))
model = VGG(make_layers(cfg['B']), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
return model
def vgg13_bn():
def vgg13_bn(**kwargs):
"""VGG 13-layer model (configuration "B") with batch normalization"""
return VGG(make_layers(cfg['B'], batch_norm=True))
return VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
def vgg16(pretrained=False):
def vgg16(pretrained=False, **kwargs):
"""VGG 16-layer model (configuration "D")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['D']))
model = VGG(make_layers(cfg['D']), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
return model
def vgg16_bn():
def vgg16_bn(**kwargs):
"""VGG 16-layer model (configuration "D") with batch normalization"""
return VGG(make_layers(cfg['D'], batch_norm=True))
return VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
def vgg19(pretrained=False):
def vgg19(pretrained=False, **kwargs):
"""VGG 19-layer model (configuration "E")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['E']))
model = VGG(make_layers(cfg['E']), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
return model
def vgg19_bn():
def vgg19_bn(**kwargs):
"""VGG 19-layer model (configuration 'E') with batch normalization"""
return VGG(make_layers(cfg['E'], batch_norm=True))
return VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment