Unverified Commit 09b7d6c7 authored by Wang Xinjiang's avatar Wang Xinjiang Committed by GitHub
Browse files

Import modules from a string list (#514)

* Custom imports

* Resolve comments

* Add unittest

* Add unittest

* Rename custom_imports to import_modules_from_strings

* Move import_modules_from_strings ito misc.py and allow failed imports

* small change

* small change

* change mmcv.runner to os.path
parent 270e470e
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from .config import Config, ConfigDict, DictAction from .config import Config, ConfigDict, DictAction
from .misc import (check_prerequisites, concat_list, deprecated_api_warning, from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
is_list_of, is_seq_of, is_str, is_tuple_of, iter_cast, import_modules_from_strings, is_list_of, is_seq_of, is_str,
list_cast, requires_executable, requires_package, is_tuple_of, iter_cast, list_cast, requires_executable,
slice_list, tuple_cast) requires_package, slice_list, tuple_cast)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink) scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress, from .progressbar import (ProgressBar, track_iter_progress,
...@@ -23,7 +23,7 @@ except ImportError: ...@@ -23,7 +23,7 @@ except ImportError:
'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar', 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
'track_progress', 'track_iter_progress', 'track_parallel_progress', 'track_progress', 'track_iter_progress', 'track_parallel_progress',
'Timer', 'TimerError', 'check_time', 'deprecated_api_warning', 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
'get_git_hash' 'get_git_hash', 'import_modules_from_strings'
] ]
else: else:
from .env import collect_env from .env import collect_env
...@@ -47,5 +47,6 @@ else: ...@@ -47,5 +47,6 @@ else:
'_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin', '_AvgPoolNd', '_BatchNorm', '_ConvNd', '_ConvTransposeMixin',
'_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension', '_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension',
'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader', 'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader',
'TORCH_VERSION', 'deprecated_api_warning', 'get_git_hash' 'TORCH_VERSION', 'deprecated_api_warning', 'get_git_hash',
'import_modules_from_strings'
] ]
...@@ -16,6 +16,54 @@ def is_str(x): ...@@ -16,6 +16,54 @@ def is_str(x):
return isinstance(x, str) return isinstance(x, str)
def import_modules_from_strings(imports, allow_failed_imports=False):
"""Import modules from the given list of strings.
Args:
imports (list | str | None): The given module names to be imported.
allow_failed_imports (bool): If True, the failed imports will return
None. Otherwise, an ImportError is raise. Default: False.
Returns:
list[module] | module | None: The imported modules.
Examples:
>>> osp, sys = import_modules_from_strings(
... ['os.path', 'sys'])
>>> import os.path as osp_
>>> import sys as sys_
>>> assert osp == osp_
>>> assert sys == sys_
"""
if not imports:
return
single_import = False
if isinstance(imports, str):
single_import = True
imports = [imports]
if not isinstance(imports, list):
raise TypeError(
f'custom_imports must be a list but got type {type(imports)}')
imported = []
for imp in imports:
if not isinstance(imp, str):
raise TypeError(
f'{imp} is of type {type(imp)} and cannot be imported.')
try:
imported_tmp = import_module(imp)
except ImportError:
if allow_failed_imports:
warnings.warn(f'{imp} failed to import and is ignored.',
UserWarning)
imported_tmp = None
else:
raise ImportError
imported.append(imported_tmp)
if single_import:
imported = imported[0]
return imported
def iter_cast(inputs, dst_type, return_type=None): def iter_cast(inputs, dst_type, return_type=None):
"""Cast elements of an iterable object into some type. """Cast elements of an iterable object into some type.
......
...@@ -100,3 +100,37 @@ def test_requires_executable(capsys): ...@@ -100,3 +100,37 @@ def test_requires_executable(capsys):
' please install them first.\n') ' please install them first.\n')
assert func_c() == 1 assert func_c() == 1
def test_import_modules_from_strings():
# multiple imports
import os.path as osp_
import sys as sys_
osp, sys = mmcv.import_modules_from_strings(['os.path', 'sys'])
assert osp == osp_
assert sys == sys_
# single imports
osp = mmcv.import_modules_from_strings('os.path')
assert osp == osp_
# No imports
assert mmcv.import_modules_from_strings(None) is None
assert mmcv.import_modules_from_strings([]) is None
assert mmcv.import_modules_from_strings('') is None
# Unsupported types
with pytest.raises(TypeError):
mmcv.import_modules_from_strings(1)
with pytest.raises(TypeError):
mmcv.import_modules_from_strings([1])
# Failed imports
with pytest.raises(ImportError):
mmcv.import_modules_from_strings('_not_implemented_module')
with pytest.warns(UserWarning):
imported = mmcv.import_modules_from_strings(
'_not_implemented_module', allow_failed_imports=True)
assert imported is None
with pytest.warns(UserWarning):
imported = mmcv.import_modules_from_strings(
['os.path', '_not_implemented'], allow_failed_imports=True)
assert imported[0] == osp
assert imported[1] is None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment