Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
5a4620bc
Unverified
Commit
5a4620bc
authored
Dec 19, 2018
by
Kai Chen
Committed by
GitHub
Dec 19, 2018
Browse files
Merge pull request #27 from OceanPang/master
add open-mmlab urls
parents
514b8f8a
c2d17dad
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
2 deletions
+23
-2
mmcv/runner/checkpoint.py
mmcv/runner/checkpoint.py
+23
-2
No files found.
mmcv/runner/checkpoint.py
View file @
5a4620bc
import
os.path
as
osp
import
pkgutil
import
time
from
collections
import
OrderedDict
from
importlib
import
import_module
import
mmcv
import
torch
from
torch.utils
import
model_zoo
open_mmlab_model_urls
=
{
'vgg16_caffe'
:
'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth'
,
# noqa: E501
'resnet50_caffe'
:
'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_caffe-788b5fa3.pth'
,
# noqa: E501
'resnet101_caffe'
:
'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_caffe-3ad79236.pth'
,
# noqa: E501
'resnext101_32x4d'
:
'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d-a5af3160.pth'
,
# noqa: E501
'resnext101_64x4d'
:
'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth'
# noqa: E501
}
def
load_state_dict
(
module
,
state_dict
,
strict
=
False
,
logger
=
None
):
"""Load state_dict to a module.
...
...
@@ -69,7 +80,7 @@ def load_checkpoint(model,
Args:
model (Module): Module to load checkpoint.
filename (str): Either a filepath or URL or modelzo
ll
://xxxxxxx.
filename (str): Either a filepath or URL or modelzo
o
://xxxxxxx.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
...
...
@@ -80,9 +91,19 @@ def load_checkpoint(model,
"""
# load checkpoint from modelzoo or file or url
if
filename
.
startswith
(
'modelzoo://'
):
from
torchvision.models.resnet
import
model_urls
import
torchvision
model_urls
=
dict
()
for
_
,
name
,
ispkg
in
pkgutil
.
walk_packages
(
torchvision
.
models
.
__path__
):
if
not
ispkg
:
_zoo
=
import_module
(
'torchvision.models.{}'
.
format
(
name
))
_urls
=
getattr
(
_zoo
,
'model_urls'
)
model_urls
.
update
(
_urls
)
model_name
=
filename
[
11
:]
checkpoint
=
model_zoo
.
load_url
(
model_urls
[
model_name
])
elif
filename
.
startswith
(
'open-mmlab://'
):
model_name
=
filename
[
13
:]
checkpoint
=
model_zoo
.
load_url
(
open_mmlab_model_urls
[
model_name
])
elif
filename
.
startswith
((
'http://'
,
'https://'
)):
checkpoint
=
model_zoo
.
load_url
(
filename
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment