Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
d359dfdf
Commit
d359dfdf
authored
Feb 23, 2017
by
Luke Yeager
Committed by
Adam Paszke
Feb 24, 2017
Browse files
Expose the num_classes argument when making models
parent
df75fa63
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
32 additions
and
32 deletions
+32
-32
torchvision/models/alexnet.py
torchvision/models/alexnet.py
+2
-2
torchvision/models/resnet.py
torchvision/models/resnet.py
+10
-10
torchvision/models/squeezenet.py
torchvision/models/squeezenet.py
+4
-4
torchvision/models/vgg.py
torchvision/models/vgg.py
+16
-16
No files found.
torchvision/models/alexnet.py
View file @
d359dfdf
...
@@ -45,14 +45,14 @@ class AlexNet(nn.Module):
...
@@ -45,14 +45,14 @@ class AlexNet(nn.Module):
return
x
return
x
def
alexnet
(
pretrained
=
False
):
def
alexnet
(
pretrained
=
False
,
**
kwargs
):
r
"""AlexNet model architecture from the
r
"""AlexNet model architecture from the
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
model
=
AlexNet
()
model
=
AlexNet
(
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'alexnet'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'alexnet'
]))
return
model
return
model
torchvision/models/resnet.py
View file @
d359dfdf
...
@@ -152,61 +152,61 @@ class ResNet(nn.Module):
...
@@ -152,61 +152,61 @@ class ResNet(nn.Module):
return
x
return
x
def
resnet18
(
pretrained
=
False
):
def
resnet18
(
pretrained
=
False
,
**
kwargs
):
"""Constructs a ResNet-18 model.
"""Constructs a ResNet-18 model.
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
model
=
ResNet
(
BasicBlock
,
[
2
,
2
,
2
,
2
])
model
=
ResNet
(
BasicBlock
,
[
2
,
2
,
2
,
2
]
,
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'resnet18'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'resnet18'
]))
return
model
return
model
def
resnet34
(
pretrained
=
False
):
def
resnet34
(
pretrained
=
False
,
**
kwargs
):
"""Constructs a ResNet-34 model.
"""Constructs a ResNet-34 model.
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
model
=
ResNet
(
BasicBlock
,
[
3
,
4
,
6
,
3
])
model
=
ResNet
(
BasicBlock
,
[
3
,
4
,
6
,
3
]
,
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'resnet34'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'resnet34'
]))
return
model
return
model
def
resnet50
(
pretrained
=
False
):
def
resnet50
(
pretrained
=
False
,
**
kwargs
):
"""Constructs a ResNet-50 model.
"""Constructs a ResNet-50 model.
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
model
=
ResNet
(
Bottleneck
,
[
3
,
4
,
6
,
3
])
model
=
ResNet
(
Bottleneck
,
[
3
,
4
,
6
,
3
]
,
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'resnet50'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'resnet50'
]))
return
model
return
model
def
resnet101
(
pretrained
=
False
):
def
resnet101
(
pretrained
=
False
,
**
kwargs
):
"""Constructs a ResNet-101 model.
"""Constructs a ResNet-101 model.
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
model
=
ResNet
(
Bottleneck
,
[
3
,
4
,
23
,
3
])
model
=
ResNet
(
Bottleneck
,
[
3
,
4
,
23
,
3
]
,
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'resnet101'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'resnet101'
]))
return
model
return
model
def
resnet152
(
pretrained
=
False
):
def
resnet152
(
pretrained
=
False
,
**
kwargs
):
"""Constructs a ResNet-152 model.
"""Constructs a ResNet-152 model.
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
model
=
ResNet
(
Bottleneck
,
[
3
,
8
,
36
,
3
])
model
=
ResNet
(
Bottleneck
,
[
3
,
8
,
36
,
3
]
,
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'resnet152'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'resnet152'
]))
return
model
return
model
torchvision/models/squeezenet.py
View file @
d359dfdf
...
@@ -101,7 +101,7 @@ class SqueezeNet(nn.Module):
...
@@ -101,7 +101,7 @@ class SqueezeNet(nn.Module):
return
x
.
view
(
x
.
size
(
0
),
self
.
num_classes
)
return
x
.
view
(
x
.
size
(
0
),
self
.
num_classes
)
def
squeezenet1_0
(
pretrained
=
False
):
def
squeezenet1_0
(
pretrained
=
False
,
**
kwargs
):
r
"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
r
"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
accuracy with 50x fewer parameters and <0.5MB model size"
accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper.
<https://arxiv.org/abs/1602.07360>`_ paper.
...
@@ -109,13 +109,13 @@ def squeezenet1_0(pretrained=False):
...
@@ -109,13 +109,13 @@ def squeezenet1_0(pretrained=False):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
model
=
SqueezeNet
(
version
=
1.0
)
model
=
SqueezeNet
(
version
=
1.0
,
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'squeezenet1_0'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'squeezenet1_0'
]))
return
model
return
model
def
squeezenet1_1
(
pretrained
=
False
):
def
squeezenet1_1
(
pretrained
=
False
,
**
kwargs
):
r
"""SqueezeNet 1.1 model from the `official SqueezeNet repo
r
"""SqueezeNet 1.1 model from the `official SqueezeNet repo
<https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
<https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
...
@@ -124,7 +124,7 @@ def squeezenet1_1(pretrained=False):
...
@@ -124,7 +124,7 @@ def squeezenet1_1(pretrained=False):
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
model
=
SqueezeNet
(
version
=
1.1
)
model
=
SqueezeNet
(
version
=
1.1
,
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'squeezenet1_1'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'squeezenet1_1'
]))
return
model
return
model
torchvision/models/vgg.py
View file @
d359dfdf
...
@@ -78,69 +78,69 @@ cfg = {
...
@@ -78,69 +78,69 @@ cfg = {
}
}
def
vgg11
(
pretrained
=
False
):
def
vgg11
(
pretrained
=
False
,
**
kwargs
):
"""VGG 11-layer model (configuration "A")
"""VGG 11-layer model (configuration "A")
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
model
=
VGG
(
make_layers
(
cfg
[
'A'
]))
model
=
VGG
(
make_layers
(
cfg
[
'A'
])
,
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg11'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg11'
]))
return
model
return
model
def
vgg11_bn
():
def
vgg11_bn
(
**
kwargs
):
"""VGG 11-layer model (configuration "A") with batch normalization"""
"""VGG 11-layer model (configuration "A") with batch normalization"""
return
VGG
(
make_layers
(
cfg
[
'A'
],
batch_norm
=
True
))
return
VGG
(
make_layers
(
cfg
[
'A'
],
batch_norm
=
True
)
,
**
kwargs
)
def
vgg13
(
pretrained
=
False
):
def
vgg13
(
pretrained
=
False
,
**
kwargs
):
"""VGG 13-layer model (configuration "B")
"""VGG 13-layer model (configuration "B")
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
model
=
VGG
(
make_layers
(
cfg
[
'B'
]))
model
=
VGG
(
make_layers
(
cfg
[
'B'
])
,
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg13'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg13'
]))
return
model
return
model
def
vgg13_bn
():
def
vgg13_bn
(
**
kwargs
):
"""VGG 13-layer model (configuration "B") with batch normalization"""
"""VGG 13-layer model (configuration "B") with batch normalization"""
return
VGG
(
make_layers
(
cfg
[
'B'
],
batch_norm
=
True
))
return
VGG
(
make_layers
(
cfg
[
'B'
],
batch_norm
=
True
)
,
**
kwargs
)
def
vgg16
(
pretrained
=
False
):
def
vgg16
(
pretrained
=
False
,
**
kwargs
):
"""VGG 16-layer model (configuration "D")
"""VGG 16-layer model (configuration "D")
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
model
=
VGG
(
make_layers
(
cfg
[
'D'
]))
model
=
VGG
(
make_layers
(
cfg
[
'D'
])
,
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg16'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg16'
]))
return
model
return
model
def
vgg16_bn
():
def
vgg16_bn
(
**
kwargs
):
"""VGG 16-layer model (configuration "D") with batch normalization"""
"""VGG 16-layer model (configuration "D") with batch normalization"""
return
VGG
(
make_layers
(
cfg
[
'D'
],
batch_norm
=
True
))
return
VGG
(
make_layers
(
cfg
[
'D'
],
batch_norm
=
True
)
,
**
kwargs
)
def
vgg19
(
pretrained
=
False
):
def
vgg19
(
pretrained
=
False
,
**
kwargs
):
"""VGG 19-layer model (configuration "E")
"""VGG 19-layer model (configuration "E")
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
"""
model
=
VGG
(
make_layers
(
cfg
[
'E'
]))
model
=
VGG
(
make_layers
(
cfg
[
'E'
])
,
**
kwargs
)
if
pretrained
:
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg19'
]))
model
.
load_state_dict
(
model_zoo
.
load_url
(
model_urls
[
'vgg19'
]))
return
model
return
model
def
vgg19_bn
():
def
vgg19_bn
(
**
kwargs
):
"""VGG 19-layer model (configuration 'E') with batch normalization"""
"""VGG 19-layer model (configuration 'E') with batch normalization"""
return
VGG
(
make_layers
(
cfg
[
'E'
],
batch_norm
=
True
))
return
VGG
(
make_layers
(
cfg
[
'E'
],
batch_norm
=
True
)
,
**
kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment