"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a2ecce26bc1181a5ed98a97910a7d0f83efb7538"
Unverified Commit 5a4620bc authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #27 from OceanPang/master

add open-mmlab urls
parents 514b8f8a c2d17dad
import os.path as osp import os.path as osp
import pkgutil
import time import time
from collections import OrderedDict from collections import OrderedDict
from importlib import import_module
import mmcv import mmcv
import torch import torch
from torch.utils import model_zoo from torch.utils import model_zoo
open_mmlab_model_urls = {
'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth', # noqa: E501
'resnet50_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_caffe-788b5fa3.pth', # noqa: E501
'resnet101_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_caffe-3ad79236.pth', # noqa: E501
'resnext101_32x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d-a5af3160.pth', # noqa: E501
'resnext101_64x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth' # noqa: E501
}
def load_state_dict(module, state_dict, strict=False, logger=None): def load_state_dict(module, state_dict, strict=False, logger=None):
"""Load state_dict to a module. """Load state_dict to a module.
...@@ -69,7 +80,7 @@ def load_checkpoint(model, ...@@ -69,7 +80,7 @@ def load_checkpoint(model,
Args: Args:
model (Module): Module to load checkpoint. model (Module): Module to load checkpoint.
filename (str): Either a filepath or URL or modelzoll://xxxxxxx. filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
map_location (str): Same as :func:`torch.load`. map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and strict (bool): Whether to allow different params for the model and
checkpoint. checkpoint.
...@@ -80,9 +91,19 @@ def load_checkpoint(model, ...@@ -80,9 +91,19 @@ def load_checkpoint(model,
""" """
# load checkpoint from modelzoo or file or url # load checkpoint from modelzoo or file or url
if filename.startswith('modelzoo://'): if filename.startswith('modelzoo://'):
from torchvision.models.resnet import model_urls import torchvision
model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(
torchvision.models.__path__):
if not ispkg:
_zoo = import_module('torchvision.models.{}'.format(name))
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
model_name = filename[11:] model_name = filename[11:]
checkpoint = model_zoo.load_url(model_urls[model_name]) checkpoint = model_zoo.load_url(model_urls[model_name])
elif filename.startswith('open-mmlab://'):
model_name = filename[13:]
checkpoint = model_zoo.load_url(open_mmlab_model_urls[model_name])
elif filename.startswith(('http://', 'https://')): elif filename.startswith(('http://', 'https://')):
checkpoint = model_zoo.load_url(filename) checkpoint = model_zoo.load_url(filename)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment