"scripts/deprecated/test_httpserver_reuse.py" did not exist on "9acc6e350475a64207a6702a579850c93ab27b43"
Unverified Commit 3b9b9054 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[Enhancement] Add support of TorchVision's Model Registration API on MMCV (#2246)

* Add support of TorchVision's Model Registration API.

* fix linter

* formatting with yapf
parent 25f533b9
......@@ -137,17 +137,28 @@ def get_torchvision_models():
json_path = osp.join(mmcv.__path__[0],
'model_zoo/torchvision_0.12.json')
model_urls = mmcv.load(json_path)
for cls_name, cls in torchvision.models.__dict__.items():
if digit_version(torchvision.__version__) < digit_version('0.14.0a0'):
weights_list = [
cls for cls_name, cls in torchvision.models.__dict__.items()
if cls_name.endswith('_Weights')
]
else:
weights_list = [
torchvision.models.get_model_weights(model)
for model in torchvision.models.list_models(torchvision.models)
]
for cls in weights_list:
# The name of torchvision model weights classes ends with
# `_Weights` such as `ResNet18_Weights`. However, some model weight
# classes, such as `MNASNet0_75_Weights` does not have any urls in
# torchvision 0.13.0 and cannot be iterated. Here we simply check
# `DEFAULT` attribute to ensure the class is not empty.
if (not cls_name.endswith('_Weights')
or not hasattr(cls, 'DEFAULT')):
if not hasattr(cls, 'DEFAULT'):
continue
# Since `cls.DEFAULT` can not be accessed by iterating cls, we set
# default urls explicitly.
cls_name = cls.__name__
cls_key = cls_name.replace('_Weights', '').lower()
model_urls[f'{cls_key}.default'] = cls.DEFAULT.url
for weight_enum in cls:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment