Commit b3238191 authored by pangjm's avatar pangjm
Browse files

minor revision

parent de70fe5c
...@@ -6,6 +6,19 @@ import mmcv ...@@ -6,6 +6,19 @@ import mmcv
import torch import torch
from torch.utils import model_zoo from torch.utils import model_zoo
open_mmlab_model_urls = {
'vgg16_caffe':
'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth',
'resnet50_caffe':
'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_caffe-788b5fa3.pth',
'resnet101_caffe':
'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_caffe-3ad79236.pth',
'resnext101_32x4d':
'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d-a5af3160.pth',
'resnext101_64x4d':
'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth'
}
def load_state_dict(module, state_dict, strict=False, logger=None): def load_state_dict(module, state_dict, strict=False, logger=None):
"""Load state_dict to a module. """Load state_dict to a module.
...@@ -80,11 +93,19 @@ def load_checkpoint(model, ...@@ -80,11 +93,19 @@ def load_checkpoint(model,
""" """
# load checkpoint from modelzoo or file or url # load checkpoint from modelzoo or file or url
if filename.startswith('modelzoo://'): if filename.startswith('modelzoo://'):
from torchvision.models.resnet import model_urls import torchvision
import pkgutil
import inspect
model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(
torchvision.models.__path__):
_zoo = getattr(torchvision.models, name)
if inspect.ismodule(_zoo):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
model_name = filename[11:] model_name = filename[11:]
checkpoint = model_zoo.load_url(model_urls[model_name]) checkpoint = model_zoo.load_url(model_urls[model_name])
elif filename.startswith('open-mmlab://'): elif filename.startswith('open-mmlab://'):
from .urls import open_mmlab_model_urls
model_name = filename[13:] model_name = filename[13:]
checkpoint = model_zoo.load_url(open_mmlab_model_urls[model_name]) checkpoint = model_zoo.load_url(open_mmlab_model_urls[model_name])
elif filename.startswith(('http://', 'https://')): elif filename.startswith(('http://', 'https://')):
......
open_mmlab_model_urls = {
'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth',
'resnet50_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_caffe-788b5fa3.pth',
'resnet101_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_caffe-3ad79236.pth',
'resnext101_32x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d-a5af3160.pth',
'resnext101_64x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth'
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment