Unverified Commit cfd337bb authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

Support load with mmcls:// (#511)

* [Feature] Support load models from mmcls

* [Feature] Support load with mmcls://

* hard-code load mmcls

* fixed wrong commit

* add json

* remove cifar
parent 06556c84
include mmcv/video/optflow_warp/*.hpp mmcv/video/optflow_warp/*.pyx include mmcv/video/optflow_warp/*.hpp mmcv/video/optflow_warp/*.pyx
include requirements.txt include requirements.txt
include mmcv/model_zoo/open_mmlab.json mmcv/model_zoo/deprecated.json include mmcv/model_zoo/open_mmlab.json mmcv/model_zoo/deprecated.json mmcv/model_zoo/mmcls.json
include mmcv/ops/csrc/*.cuh mmcv/ops/csrc/*.hpp include mmcv/ops/csrc/*.cuh mmcv/ops/csrc/*.hpp
include mmcv/ops/csrc/pytorch/*.cu mmcv/ops/csrc/pytorch/*.cpp include mmcv/ops/csrc/pytorch/*.cu mmcv/ops/csrc/pytorch/*.cpp
include mmcv/ops/csrc/parrots/*.cu mmcv/ops/csrc/parrots/*.cpp include mmcv/ops/csrc/parrots/*.cu mmcv/ops/csrc/parrots/*.cpp
{
"resnet50_v1d": "https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnetv1d50_batch256_20200708-1ad0ce94.pth",
"resnet101_v1d": "https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnetv1d101_batch256_20200708-9cb302ef.pth",
"resnet152_v1d": "https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnetv1d152_batch256_20200708-e79cb6a2.pth",
"resnext50": "https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnext50_32x4d_batch256_20200708-c07adbb7.pth",
"resnext101": "https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnext101_32x8d_batch256_20200708-1ec34aa7.pth",
"resnext152": "https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnext152_32x4d_batch256_20200708-aab5034c.pth",
"se-resnet50": "https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/se-resnet50_batch256_20200804-ae206104.pth",
"se-resnet101": "https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/se-resnet101_batch256_20200804-ba5b51d4.pth",
"shufflenet_v1": "https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/shufflenet_v1_batch1024_20200804-5d6cec73.pth",
"shufflenet_v2": "https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/shufflenet_v2_batch1024_20200812-5bf4721e.pth",
"mobilenet_v2": "https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/mobilenet_v2_batch256_20200708-3b2dc3af.pth"
}
...@@ -142,6 +142,13 @@ def get_external_models(): ...@@ -142,6 +142,13 @@ def get_external_models():
return default_urls return default_urls
def get_mmcls_models():
mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
mmcls_urls = load_file(mmcls_json_path)
return mmcls_urls
def get_deprecated_model_names(): def get_deprecated_model_names():
deprecate_json_path = osp.join(mmcv.__path__[0], deprecate_json_path = osp.join(mmcv.__path__[0],
'model_zoo/deprecated.json') 'model_zoo/deprecated.json')
...@@ -151,6 +158,17 @@ def get_deprecated_model_names(): ...@@ -151,6 +158,17 @@ def get_deprecated_model_names():
return deprecate_urls return deprecate_urls
def _process_mmcls_checkpoint(checkpoint):
state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('backbone.'):
new_state_dict[k[9:]] = v
new_checkpoint = dict(state_dict=new_state_dict)
return new_checkpoint
def _load_checkpoint(filename, map_location=None): def _load_checkpoint(filename, map_location=None):
"""Load checkpoint from somewhere (modelzoo, file, url). """Load checkpoint from somewhere (modelzoo, file, url).
...@@ -192,6 +210,11 @@ def _load_checkpoint(filename, map_location=None): ...@@ -192,6 +210,11 @@ def _load_checkpoint(filename, map_location=None):
if not osp.isfile(filename): if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file') raise IOError(f'{filename} is not a checkpoint file')
checkpoint = torch.load(filename, map_location=map_location) checkpoint = torch.load(filename, map_location=map_location)
elif filename.startswith('mmcls://'):
model_urls = get_mmcls_models()
model_name = filename[8:]
checkpoint = load_url_dist(model_urls[model_name])
checkpoint = _process_mmcls_checkpoint(checkpoint)
elif filename.startswith(('http://', 'https://')): elif filename.startswith(('http://', 'https://')):
checkpoint = load_url_dist(filename) checkpoint = load_url_dist(filename)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment