"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "49b959b5408b97274e2ee423059d9239445aea26"
Unverified Commit 3b9b9054 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[Enhancement] Add support of TorchVision's Model Registration API on MMCV (#2246)

* Add support of TorchVision's Model Registration API.

* fix linter

* formatting with yapf
parent 25f533b9
...@@ -137,17 +137,28 @@ def get_torchvision_models(): ...@@ -137,17 +137,28 @@ def get_torchvision_models():
json_path = osp.join(mmcv.__path__[0], json_path = osp.join(mmcv.__path__[0],
'model_zoo/torchvision_0.12.json') 'model_zoo/torchvision_0.12.json')
model_urls = mmcv.load(json_path) model_urls = mmcv.load(json_path)
for cls_name, cls in torchvision.models.__dict__.items(): if digit_version(torchvision.__version__) < digit_version('0.14.0a0'):
weights_list = [
cls for cls_name, cls in torchvision.models.__dict__.items()
if cls_name.endswith('_Weights')
]
else:
weights_list = [
torchvision.models.get_model_weights(model)
for model in torchvision.models.list_models(torchvision.models)
]
for cls in weights_list:
# The name of torchvision model weights classes ends with # The name of torchvision model weights classes ends with
# `_Weights` such as `ResNet18_Weights`. However, some model weight # `_Weights` such as `ResNet18_Weights`. However, some model weight
# classes, such as `MNASNet0_75_Weights` does not have any urls in # classes, such as `MNASNet0_75_Weights` does not have any urls in
# torchvision 0.13.0 and cannot be iterated. Here we simply check # torchvision 0.13.0 and cannot be iterated. Here we simply check
# `DEFAULT` attribute to ensure the class is not empty. # `DEFAULT` attribute to ensure the class is not empty.
if (not cls_name.endswith('_Weights') if not hasattr(cls, 'DEFAULT'):
or not hasattr(cls, 'DEFAULT')):
continue continue
# Since `cls.DEFAULT` can not be accessed by iterating cls, we set # Since `cls.DEFAULT` can not be accessed by iterating cls, we set
# default urls explicitly. # default urls explicitly.
cls_name = cls.__name__
cls_key = cls_name.replace('_Weights', '').lower() cls_key = cls_name.replace('_Weights', '').lower()
model_urls[f'{cls_key}.default'] = cls.DEFAULT.url model_urls[f'{cls_key}.default'] = cls.DEFAULT.url
for weight_enum in cls: for weight_enum in cls:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment