Unverified Commit 3e1e297d authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

use torchvision:// instead of modelzoo:// (#89)

parent e5ca8846
...@@ -2,11 +2,13 @@ import os ...@@ -2,11 +2,13 @@ import os
import os.path as osp import os.path as osp
import pkgutil import pkgutil
import time import time
import warnings
from collections import OrderedDict from collections import OrderedDict
from importlib import import_module from importlib import import_module
import mmcv import mmcv
import torch import torch
import torchvision
from torch.utils import model_zoo from torch.utils import model_zoo
from .utils import get_dist_info from .utils import get_dist_info
...@@ -85,7 +87,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None): ...@@ -85,7 +87,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
if strict: if strict:
raise RuntimeError(err_msg) raise RuntimeError(err_msg)
elif logger is not None: elif logger is not None:
logger.warn(err_msg) logger.warning(err_msg)
else: else:
print(err_msg) print(err_msg)
...@@ -104,6 +106,18 @@ def load_url_dist(url): ...@@ -104,6 +106,18 @@ def load_url_dist(url):
return checkpoint return checkpoint
def get_torchvision_models():
model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module('torchvision.models.{}'.format(name))
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
return model_urls
def load_checkpoint(model, def load_checkpoint(model,
filename, filename,
map_location=None, map_location=None,
...@@ -124,17 +138,15 @@ def load_checkpoint(model, ...@@ -124,17 +138,15 @@ def load_checkpoint(model,
""" """
# load checkpoint from modelzoo or file or url # load checkpoint from modelzoo or file or url
if filename.startswith('modelzoo://'): if filename.startswith('modelzoo://'):
import torchvision warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
model_urls = dict() 'use "torchvision://" instead')
for _, name, ispkg in pkgutil.walk_packages( model_urls = get_torchvision_models()
torchvision.models.__path__):
if not ispkg:
_zoo = import_module('torchvision.models.{}'.format(name))
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
model_name = filename[11:] model_name = filename[11:]
checkpoint = load_url_dist(model_urls[model_name]) checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('torchvision://'):
model_urls = get_torchvision_models()
model_name = filename[14:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('open-mmlab://'): elif filename.startswith('open-mmlab://'):
model_name = filename[13:] model_name = filename[13:]
checkpoint = load_url_dist(open_mmlab_model_urls[model_name]) checkpoint = load_url_dist(open_mmlab_model_urls[model_name])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment