Commit 37d8facf authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

Support default/external json for open-mmlab models (#230)

* support default/external json for open-mmlab models

* add local

* add more test

* add docs

* add docs

* update docs

* refactor

* add json in MANIFEST

* fixed json typo
parent c63e4d57
include mmcv/video/optflow_warp/*.hpp mmcv/video/optflow_warp/*.pyx include mmcv/video/optflow_warp/*.hpp mmcv/video/optflow_warp/*.pyx
include requirements.txt include requirements.txt
include mmcv/model_zoo/open_mmlab.json
...@@ -13,6 +13,7 @@ Contents ...@@ -13,6 +13,7 @@ Contents
utils.md utils.md
runner.md runner.md
cnn.md cnn.md
model_zoo.md
api.rst api.rst
......
## Model Zoo
Besides torchvision pre-trained models, we also provide pre-trained models of following CNN:
* VGG Caffe
* ResNet Caffe
* ResNeXt
* ResNet with Group Normalization
* ResNet with Group Normalization and Weight Standardization
* HRNetV2
* Res2Net
* RegNet
### Model URLs in JSON
The model zoo links in MMCV are managed by JSON files.
The json file consists of key-value pair of model name and its url or path.
An example json file could be like:
```json
{
"model_a": "https://example.com/models/model_a_9e5bac.pth",
"model_b": "pretrain/model_b_ab3ef2c.pth"
}
```
The default links of the pre-trained models hosted on Open-MMLab AWS could be found [here](../mmcv/model_zoo/open_mmlab.json).
You may override default links by putting `open-mmlab.json` under `MMCV_HOME`. If `MMCV_HOME` is not find in the environment, `~/.cache/mmcv` will be used by default. You may `export MMCV_HOME=/your/path` to use your own path.
The external json files will be merged into default one. If the same key presents in both external json and default json, the external one will be used.
### Load Checkpoint
The following types are supported for `filename` argument of `mmcv.load_checkpoint()`.
* filepath: The filepath of the checkpoint.
* `http://xxx` and `https://xxx`: The link to download the checkpoint. The `SHA256` postfix should be contained in the filename.
* `torchvison://xxx`: The model links in `torchvision.models`.Please refer to [torchvision](https://pytorch.org/docs/stable/torchvision/models.html) for details.
* `open-mmlab://xxx`: The model links or filepath provided in default and additional json files.
{
"vgg16_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
"detectron/resnet50_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
"detectron2/resnet50_caffe_bgr": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_msra-5891d200.pth",
"detectron/resnet101_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
"detectron2/resnet101_caffe_bgr": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
"resnext50_32x4d": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
"resnext101_32x4d": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
"resnext101_64x4d": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
"contrib/resnet50_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
"detectron/resnet50_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
"detectron/resnet101_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
"jhu/resnet50_gn_ws": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
"jhu/resnet101_gn_ws": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
"jhu/resnext50_32x4d_gn_ws": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
"jhu/resnext101_32x4d_gn_ws": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
"jhu/resnext50_32x4d_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
"jhu/resnext101_32x4d_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
"msra/hrnetv2_w18": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
"msra/hrnetv2_w32": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
"msra/hrnetv2_w40": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
"bninception_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
"kin400/i3d_r50_f32s2_k400": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
"kin400/nl3d_r50_f32s2_k400": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
"res2net101_v1d_26w_4s": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
"regnetx_800mf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
"regnetx_1.6gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
"regnetx_3.2gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
"regnetx_4.0gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
"regnetx_6.4gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
"regnetx_8.0gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
"regnetx_12gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth"
}
...@@ -12,41 +12,24 @@ import torchvision ...@@ -12,41 +12,24 @@ import torchvision
from torch.utils import model_zoo from torch.utils import model_zoo
import mmcv import mmcv
from ..fileio import load as load_file
from ..utils import mkdir_or_exist
from .dist_utils import get_dist_info from .dist_utils import get_dist_info
open_mmlab_model_urls = { ENV_MMCV_HOME = 'MMCV_HOME'
'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth', # noqa: E501 ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
'resnet50_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_caffe-788b5fa3.pth', # noqa: E501 DEFAULT_CACHE_DIR = '~/.cache'
'resnet50_caffe_bgr': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_msra-5891d200.pth', # noqa: E501
'resnet101_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_caffe-3ad79236.pth', # noqa: E501
'resnet101_caffe_bgr': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_msra-6cc46731.pth', # noqa: E501 def _get_mmcv_home():
'resnext50_32x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50-32x4d-0ab1a123.pth', # noqa: E501 mmcv_home = os.path.expanduser(
'resnext101_32x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d-a5af3160.pth', # noqa: E501 os.getenv(
'resnext101_64x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth', # noqa: E501 ENV_MMCV_HOME,
'contrib/resnet50_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth', # noqa: E501 os.path.join(
'detectron/resnet50_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn-9186a21c.pth', # noqa: E501 os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
'detectron/resnet101_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_gn-cac0ab98.pth', # noqa: E501
'jhu/resnet50_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn_ws-15beedd8.pth', # noqa: E501 mkdir_or_exist(mmcv_home)
'jhu/resnet101_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth', # noqa: E501 return mmcv_home
'jhu/resnext50_32x4d_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth', # noqa: E501
'jhu/resnext101_32x4d_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth', # noqa: E501
'jhu/resnext50_32x4d_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth', # noqa: E501
'jhu/resnext101_32x4d_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth', # noqa: E501
'msra/hrnetv2_w18': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w18-00eb2006.pth', # noqa: E501
'msra/hrnetv2_w32': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth', # noqa: E501
'msra/hrnetv2_w40': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w40-ed0b031c.pth', # noqa: E501
'bninception_caffe': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth', # noqa: E501
'kin400/i3d_r50_f32s2_k400': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth', # noqa: E501
'kin400/nl3d_r50_f32s2_k400': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth', # noqa: E501
'regnetx_800mf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth', # noqa: E501
'regnetx_1.6gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth', # noqa: E501
'regnetx_3.2gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth', # noqa: E501
'regnetx_4.0gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth', # noqa: E501
'regnetx_6.4gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth', # noqa: E501
'regnetx_8.0gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth', # noqa: E501
'regnetx_12gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth', # noqa: E501
'res2net101_v1d_26w_4s': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth', # noqa: E501
} # yapf: disable
def load_state_dict(module, state_dict, strict=False, logger=None): def load_state_dict(module, state_dict, strict=False, logger=None):
...@@ -113,17 +96,17 @@ def load_state_dict(module, state_dict, strict=False, logger=None): ...@@ -113,17 +96,17 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
print(err_msg) print(err_msg)
def load_url_dist(url): def load_url_dist(url, model_dir=None):
""" In distributed setting, this function only download checkpoint at """ In distributed setting, this function only download checkpoint at
local rank 0 """ local rank 0 """
rank, world_size = get_dist_info() rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank)) rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0: if rank == 0:
checkpoint = model_zoo.load_url(url) checkpoint = model_zoo.load_url(url, model_dir=model_dir)
if world_size > 1: if world_size > 1:
torch.distributed.barrier() torch.distributed.barrier()
if rank > 0: if rank > 0:
checkpoint = model_zoo.load_url(url) checkpoint = model_zoo.load_url(url, model_dir=model_dir)
return checkpoint return checkpoint
...@@ -139,11 +122,27 @@ def get_torchvision_models(): ...@@ -139,11 +122,27 @@ def get_torchvision_models():
return model_urls return model_urls
def get_external_models():
mmcv_home = _get_mmcv_home()
default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
default_urls = load_file(default_json_path)
assert isinstance(default_urls, dict)
external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
if osp.exists(external_json_path):
external_urls = load_file(external_json_path)
assert isinstance(external_urls, dict)
default_urls.update(external_urls)
return default_urls
def _load_checkpoint(filename, map_location=None): def _load_checkpoint(filename, map_location=None):
"""Load checkpoint from somewhere (modelzoo, file, url). """Load checkpoint from somewhere (modelzoo, file, url).
Args: Args:
filename (str): Either a filepath or URI. filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str | None): Same as :func:`torch.load`. Default: None. map_location (str | None): Same as :func:`torch.load`. Default: None.
Returns: Returns:
...@@ -162,8 +161,17 @@ def _load_checkpoint(filename, map_location=None): ...@@ -162,8 +161,17 @@ def _load_checkpoint(filename, map_location=None):
model_name = filename[14:] model_name = filename[14:]
checkpoint = load_url_dist(model_urls[model_name]) checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('open-mmlab://'): elif filename.startswith('open-mmlab://'):
model_urls = get_external_models()
model_name = filename[13:] model_name = filename[13:]
checkpoint = load_url_dist(open_mmlab_model_urls[model_name]) model_url = model_urls[model_name]
# check if is url
if model_url.startswith(('http://', 'https://')):
checkpoint = load_url_dist(model_url)
else:
filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
checkpoint = torch.load(filename, map_location=map_location)
elif filename.startswith(('http://', 'https://')): elif filename.startswith(('http://', 'https://')):
checkpoint = load_url_dist(filename) checkpoint = load_url_dist(filename)
else: else:
...@@ -182,7 +190,9 @@ def load_checkpoint(model, ...@@ -182,7 +190,9 @@ def load_checkpoint(model,
Args: Args:
model (Module): Module to load checkpoint. model (Module): Module to load checkpoint.
filename (str): Either a filepath or URL or modelzoo://xxxxxxx. filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str): Same as :func:`torch.load`. map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and strict (bool): Whether to allow different params for the model and
checkpoint. checkpoint.
......
{
"test": "test.pth",
"val": "val.pth",
"train_empty": "train.pth"
}
\ No newline at end of file
{
"train": "https://localhost/train.pth",
"test": "https://localhost/test.pth"
}
\ No newline at end of file
# Copyright (c) Open-MMLab. All rights reserved.
import os
import os.path as osp
from unittest.mock import patch
import pytest
import mmcv
from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
ENV_XDG_CACHE_HOME, _get_mmcv_home,
_load_checkpoint, get_external_models)
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
def test_set_mmcv_home():
os.environ.pop(ENV_MMCV_HOME, None)
mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
os.environ[ENV_MMCV_HOME] = mmcv_home
assert _get_mmcv_home() == mmcv_home
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
def test_default_mmcv_home():
os.environ.pop(ENV_MMCV_HOME, None)
os.environ.pop(ENV_XDG_CACHE_HOME, None)
assert _get_mmcv_home() == os.path.expanduser(
os.path.join(DEFAULT_CACHE_DIR, 'mmcv'))
model_urls = get_external_models()
assert model_urls == mmcv.load(
osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
def test_get_external_models():
os.environ.pop(ENV_MMCV_HOME, None)
mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
os.environ[ENV_MMCV_HOME] = mmcv_home
ext_urls = get_external_models()
assert ext_urls == {
'train': 'https://localhost/train.pth',
'test': 'test.pth',
'val': 'val.pth',
'train_empty': 'train.pth'
}
def load_url_dist(url):
return 'url:' + url
def load(filepath, map_location=None):
return 'local:' + filepath
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch('mmcv.runner.checkpoint.load_url_dist', load_url_dist)
@patch('torch.load', load)
def test_load_external_url():
# test modelzoo://
url = _load_checkpoint('modelzoo://resnet50')
assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
'.pth'
# test torchvision://
url = _load_checkpoint('torchvision://resnet50')
assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
'.pth'
# test open-mmlab:// with default MMCV_HOME
os.environ.pop(ENV_MMCV_HOME, None)
os.environ.pop(ENV_XDG_CACHE_HOME, None)
url = _load_checkpoint('open-mmlab://train')
assert url == 'url:https://localhost/train.pth'
# test open-mmlab:// with user-defined MMCV_HOME
os.environ.pop(ENV_MMCV_HOME, None)
mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home')
os.environ[ENV_MMCV_HOME] = mmcv_home
url = _load_checkpoint('open-mmlab://train')
assert url == 'url:https://localhost/train.pth'
with pytest.raises(IOError, match='train.pth is not a checkpoint ' 'file'):
_load_checkpoint('open-mmlab://train_empty')
url = _load_checkpoint('open-mmlab://test')
assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
url = _load_checkpoint('open-mmlab://val')
assert url == f'local:{osp.join(_get_mmcv_home(), "val.pth")}'
# test http:// https://
url = _load_checkpoint('http://localhost/train.pth')
assert url == 'url:http://localhost/train.pth'
# test local file
with pytest.raises(IOError, match='train.pth is not a checkpoint ' 'file'):
_load_checkpoint('train.pth')
url = _load_checkpoint(osp.join(_get_mmcv_home(), 'test.pth'))
assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment