Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
a2e6b709
Commit
a2e6b709
authored
Apr 16, 2019
by
Ailing
Committed by
Francisco Massa
Apr 16, 2019
Browse files
More hub models (#851)
* add all available models * simplify vision hubconf
parent
74b6a750
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
64 deletions
+8
-64
hubconf.py
hubconf.py
+8
-64
No files found.
hubconf.py
View file @
a2e6b709
'''
This is an example hubconf.py for pytorch/vision repo
## Users can get this published model by calling:
hub_model = hub.load(
'pytorch/vision:master', # repo_owner/repo_name:branch
'resnet18', # entrypoint
1234, # args for callable [not applicable to resnet]
pretrained=True) # kwargs for callable
## Protocol on repo owner side
1. The "published" models should be at least in a branch/tag. It can't be a random commit.
2. Repo owner should have the following field defined in hubconf.py
2.1 Function/entrypoint with function signature "def resnet18(pretrained=False, *args, **kwargs):"
2.2 Pretrained allows users to load pretrained weights from repo owner.
2.3 Args and kwargs are passed to the callable _resnet18,
2.4 Docstring of function works as a help message, explaining what does the model do and what's
the allowed arguments.
2.5 Dependencies is a list optionally provided by repo owner, to specify what packages are required
to run the model.
## Hub_dir
hub_dir specifies where the intermediate files/folders will be saved. By default this is ~/.torch/hub.
Users can change it by either setting the environment variable TORCH_HUB_DIR or calling hub.set_dir(PATH_TO_HUB_DIR).
By default, we don't cleanup files after loading so that users can use cache next time.
## Cache logic
We used the cache by default if it exists in hub_dir.
Users can force a fresh reload by calling hub.load(..., force_reload=True).
'''
import
torch.utils.model_zoo
as
model_zoo
# Optional list of dependencies required by the package
dependencies
=
[
'torch'
,
'math'
]
def
resnet18
(
pretrained
=
False
,
*
args
,
**
kwargs
):
"""
Resnet18 model
pretrained (bool): a recommended kwargs for all entrypoints
args & kwargs are arguments for the function
"""
from
torchvision.models.resnet
import
resnet18
as
_resnet18
model
=
_resnet18
(
*
args
,
**
kwargs
)
checkpoint
=
'https://download.pytorch.org/models/resnet18-5c106cde.pth'
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
checkpoint
,
progress
=
False
))
return
model
def
resnet50
(
pretrained
=
False
,
*
args
,
**
kwargs
):
"""
Resnet50 model
pretrained (bool): a recommended kwargs for all entrypoints
args & kwargs are arguments for the function
"""
from
torchvision.models.resnet
import
resnet50
as
_resnet50
model
=
_resnet50
(
*
args
,
**
kwargs
)
checkpoint
=
'https://download.pytorch.org/models/resnet50-19c8e357.pth'
if
pretrained
:
model
.
load_state_dict
(
model_zoo
.
load_url
(
checkpoint
,
progress
=
False
))
return
model
dependencies
=
[
'torch'
]
from
torchvision.models.alexnet
import
alexnet
from
torchvision.models.densenet
import
densenet121
,
densenet169
,
densenet201
,
densenet161
from
torchvision.models.inception
import
inception_v3
from
torchvision.models.resnet
import
resnet18
,
resnet34
,
resnet50
,
resnet101
,
resnet152
from
torchvision.models.squeezenet
import
squeezenet1_0
,
squeezenet1_1
from
torchvision.models.vgg
import
vgg11
,
vgg13
,
vgg16
,
vgg19
,
vgg11_bn
,
vgg13_bn
,
vgg16_bn
,
vgg19_bn
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment