Unverified Commit 5947178e authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

Remove many functions in utils and migrate them to mmengine (#2217)

* Remove runner, parallel, engine and device

* fix format

* remove outdated docs

* migrate many functions to mmengine

* remove sync_bn.py
parent 9185eee8
......@@ -6,11 +6,11 @@ import torch
import torch.nn as nn
from mmengine import print_log
from mmengine.registry import MODELS
from mmengine.utils import deprecated_api_warning
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single
from mmcv.utils import deprecated_api_warning
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
......
......@@ -3,16 +3,16 @@ import math
import warnings
from typing import Optional, no_type_check
import mmengine
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.model.utils import constant_init, xavier_init
from mmengine.registry import MODELS
from mmengine.utils import deprecated_api_warning
from torch.autograd.function import Function, once_differentiable
import mmcv
from mmcv import deprecated_api_warning
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
......@@ -193,7 +193,7 @@ class MultiScaleDeformableAttention(BaseModule):
dropout: float = 0.1,
batch_first: bool = False,
norm_cfg: Optional[dict] = None,
init_cfg: Optional[mmcv.ConfigDict] = None):
init_cfg: Optional[mmengine.ConfigDict] = None):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
......
......@@ -3,9 +3,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from mmengine.utils import deprecated_api_warning
from torch import Tensor
from mmcv.utils import deprecated_api_warning
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
......
......@@ -3,9 +3,10 @@ from typing import Any, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine.utils import is_tuple_of
from torch.autograd import Function
from ..utils import ext_loader, is_tuple_of
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['riroi_align_rotated_forward', 'riroi_align_rotated_backward'])
......
......@@ -3,11 +3,12 @@ from typing import Any
import torch
import torch.nn as nn
from mmengine.utils import deprecated_api_warning
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from ..utils import deprecated_api_warning, ext_loader
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext',
['roi_align_forward', 'roi_align_backward'])
......
......@@ -3,10 +3,11 @@ from typing import Any, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine.utils import deprecated_api_warning
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from ..utils import deprecated_api_warning, ext_loader
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward'])
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Tuple, Union
import mmengine
import torch
from torch import nn as nn
from torch.autograd import Function
import mmcv
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
......@@ -86,7 +86,7 @@ class RoIAwarePool3dFunction(Function):
out_x = out_y = out_z = out_size
else:
assert len(out_size) == 3
assert mmcv.is_tuple_of(out_size, int)
assert mmengine.is_tuple_of(out_size, int)
out_x, out_y, out_z = out_size
num_rois = rois.shape[0]
......
......@@ -4,10 +4,10 @@ import torch.nn as nn
import torch.nn.functional as F
from mmengine.model.utils import constant_init
from mmengine.registry import MODELS
from mmengine.utils import TORCH_VERSION, digit_version
from mmcv.cnn import ConvAWS2d
from mmcv.ops.deform_conv import deform_conv2d
from mmcv.utils import TORCH_VERSION, digit_version
@MODELS.register_module(name='SAC')
......
......@@ -98,10 +98,10 @@
from typing import Any, List, Tuple, Union
import torch
from mmengine.utils import to_2tuple
from torch.autograd import Function
from torch.nn import functional as F
from mmcv.utils import to_2tuple
from ..utils import ext_loader
upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union
import mmengine
import numpy as np
import torch
import mmcv
from .base import BaseTransform
from .builder import TRANSFORMS
......@@ -29,7 +29,7 @@ def to_tensor(
return data
elif isinstance(data, np.ndarray):
return torch.from_numpy(data)
elif isinstance(data, Sequence) and not mmcv.is_str(data):
elif isinstance(data, Sequence) and not mmengine.is_str(data):
return torch.tensor(data)
elif isinstance(data, int):
return torch.LongTensor([data])
......
......@@ -3,6 +3,7 @@ import random
import warnings
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import mmengine
import numpy as np
import mmcv
......@@ -797,7 +798,7 @@ class MultiScaleFlipAug(BaseTransform):
if scales is not None:
self.scales = scales if isinstance(scales, list) else [scales]
self.scale_key = 'scale'
assert mmcv.is_list_of(self.scales, tuple)
assert mmengine.is_list_of(self.scales, tuple)
else:
# if ``scales`` and ``scale_factor`` both be ``None``
if scale_factor is None:
......@@ -812,7 +813,7 @@ class MultiScaleFlipAug(BaseTransform):
self.allow_flip = allow_flip
self.flip_direction = flip_direction if isinstance(
flip_direction, list) else [flip_direction]
assert mmcv.is_list_of(self.flip_direction, str)
assert mmengine.is_list_of(self.flip_direction, str)
if not self.allow_flip and self.flip_direction != ['horizontal']:
warnings.warn(
'flip_direction has no effect when flip is set to False')
......@@ -934,7 +935,7 @@ class RandomChoiceResize(BaseTransform):
self.scales = scales
else:
self.scales = [scales]
assert mmcv.is_list_of(self.scales, tuple)
assert mmengine.is_list_of(self.scales, tuple)
self.resize_cfg = dict(type=resize_type, **resize_kwargs)
# create a empty Resize object
......@@ -950,7 +951,7 @@ class RandomChoiceResize(BaseTransform):
``scale_idx`` is the selected index in the given candidates.
"""
assert mmcv.is_list_of(self.scales, tuple)
assert mmengine.is_list_of(self.scales, tuple)
scale_idx = np.random.randint(len(self.scales))
scale = self.scales[scale_idx]
return scale, scale_idx
......@@ -1033,7 +1034,7 @@ class RandomFlip(BaseTransform):
direction: Union[str,
Sequence[Optional[str]]] = 'horizontal') -> None:
if isinstance(prob, list):
assert mmcv.is_list_of(prob, float)
assert mmengine.is_list_of(prob, float)
assert 0 <= sum(prob) <= 1
elif isinstance(prob, float):
assert 0 <= prob <= 1
......@@ -1046,7 +1047,7 @@ class RandomFlip(BaseTransform):
if isinstance(direction, str):
assert direction in valid_directions
elif isinstance(direction, list):
assert mmcv.is_list_of(direction, str)
assert mmengine.is_list_of(direction, str)
assert set(direction).issubset(set(valid_directions))
else:
raise ValueError(f'direction must be either str or list of str, \
......@@ -1308,7 +1309,7 @@ class RandomResize(BaseTransform):
tuple: The targeted scale of the image to be resized.
"""
assert mmcv.is_list_of(scales, tuple) and len(scales) == 2
assert mmengine.is_list_of(scales, tuple) and len(scales) == 2
scale_0 = [scales[0][0], scales[1][0]]
scale_1 = [scales[0][1], scales[1][1]]
edge_0 = np.random.randint(min(scale_0), max(scale_0) + 1)
......@@ -1350,12 +1351,12 @@ class RandomResize(BaseTransform):
tuple: The targeted scale of the image to be resized.
"""
if mmcv.is_tuple_of(self.scale, int):
if mmengine.is_tuple_of(self.scale, int):
assert self.ratio_range is not None and len(self.ratio_range) == 2
scale = self._random_sample_ratio(
self.scale, # type: ignore
self.ratio_range)
elif mmcv.is_seq_of(self.scale, tuple):
elif mmengine.is_seq_of(self.scale, tuple):
scale = self._random_sample(self.scale) # type: ignore
else:
raise NotImplementedError('Do not support sampling function '
......
......@@ -2,9 +2,9 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import mmengine
import numpy as np
import mmcv
from .base import BaseTransform
from .builder import TRANSFORMS
from .utils import cache_random_params, cache_randomness
......@@ -569,7 +569,7 @@ class RandomChoice(BaseTransform):
super().__init__()
if prob is not None:
assert mmcv.is_seq_of(prob, float)
assert mmengine.is_seq_of(prob, float)
assert len(transforms) == len(prob), \
'``transforms`` and ``prob`` must have same lengths. ' \
f'Got {len(transforms)} vs {len(prob)}.'
......
# flake8: noqa
# Copyright (c) OpenMMLab. All rights reserved.
from .config import Config, ConfigDict, DictAction
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
has_method, import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package,
slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
to_ntuple, tuple_cast)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
track_parallel_progress, track_progress)
from .testing import (assert_attrs_equal, assert_dict_contains_subset,
assert_dict_has_keys, assert_is_norm_layer,
assert_keys_equal, assert_params_all_zeros,
check_python_script)
from .timer import Timer, TimerError, check_time
from .version_utils import digit_version, get_git_hash
from .device_type import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
from .env import collect_env
from .parrots_jit import jit, skip_no_elena
try:
import torch
except ImportError:
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast',
'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of',
'slice_list', 'concat_list', 'check_prerequisites', 'requires_package',
'requires_executable', 'is_filepath', 'fopen', 'check_file_exist',
'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
'track_progress', 'track_iter_progress', 'track_parallel_progress',
'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
'digit_version', 'get_git_hash', 'import_modules_from_strings',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
'is_method_overridden', 'has_method'
]
else:
from .device_type import (IS_IPU_AVAILABLE, IS_MLU_AVAILABLE,
IS_MPS_AVAILABLE)
from .env import collect_env
from .logging import get_logger, print_log
from .parrots_jit import jit, skip_no_elena
# yapf: disable
from .parrots_wrapper import (IS_CUDA_AVAILABLE, TORCH_VERSION,
BuildExtension, CppExtension, CUDAExtension,
DataLoader, PoolDataLoader, SyncBatchNorm,
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
_AvgPoolNd, _BatchNorm, _ConvNd,
_ConvTransposeMixin, _get_cuda_home,
_InstanceNorm, _MaxPoolNd, get_build_config,
is_rocm_pytorch)
# yapf: enable
from .registry import Registry, build_from_cfg
from .seed import worker_init_fn
from .torch_ops import torch_meshgrid
from .trace import is_jit_tracing
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
'check_prerequisites', 'requires_package', 'requires_executable',
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist',
'symlink', 'scandir', 'ProgressBar', 'track_progress',
'track_iter_progress', 'track_parallel_progress', 'Registry',
'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm',
'_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm',
'_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd',
'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension',
'DataLoader', 'PoolDataLoader', 'TORCH_VERSION',
'deprecated_api_warning', 'digit_version', 'get_git_hash',
'import_modules_from_strings', 'jit', 'skip_no_elena',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
'_get_cuda_home', 'has_method', 'IS_CUDA_AVAILABLE', 'worker_init_fn',
'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE', 'IS_MPS_AVAILABLE',
'torch_meshgrid'
]
__all__ = [
'IS_MLU_AVAILABLE', 'IS_MPS_AVAILABLE', 'IS_CUDA_AVAILABLE', 'collect_env',
'jit', 'skip_no_elena'
]
# Copyright (c) OpenMMLab. All rights reserved.
import ast
import copy
import os
import os.path as osp
import platform
import shutil
import sys
import tempfile
import types
import uuid
import warnings
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module
from pathlib import Path
import mmengine
from addict import Dict
from yapf.yapflib.yapf_api import FormatCode
from .misc import import_modules_from_strings
from .path import check_file_exist
if platform.system() == 'Windows':
import regex as re # type: ignore
else:
import re # type: ignore
BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
DEPRECATION_KEY = '_deprecation_'
RESERVED_KEYS = ['filename', 'text', 'pretty_text']
class ConfigDict(Dict):
def __missing__(self, name):
raise KeyError(name)
def __getattr__(self, name):
try:
value = super().__getattr__(name)
except KeyError:
ex = AttributeError(f"'{self.__class__.__name__}' object has no "
f"attribute '{name}'")
except Exception as e:
ex = e
else:
return value
raise ex
def add_args(parser, cfg, prefix=''):
for k, v in cfg.items():
if isinstance(v, str):
parser.add_argument('--' + prefix + k)
elif isinstance(v, int):
parser.add_argument('--' + prefix + k, type=int)
elif isinstance(v, float):
parser.add_argument('--' + prefix + k, type=float)
elif isinstance(v, bool):
parser.add_argument('--' + prefix + k, action='store_true')
elif isinstance(v, dict):
add_args(parser, v, prefix + k + '.')
elif isinstance(v, abc.Iterable):
parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
else:
print(f'cannot parse key {prefix + k} of type {type(v)}')
return parser
class Config:
"""A facility for config and config files.
It supports common file formats as configs: python/json/yaml. The interface
is the same as a dict object and also allows access config values as
attributes.
Example:
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
>>> cfg.a
1
>>> cfg.b
{'b1': [0, 1]}
>>> cfg.b.b1
[0, 1]
>>> cfg = Config.fromfile('tests/data/config/a.py')
>>> cfg.filename
"/home/kchen/projects/mmcv/tests/data/config/a.py"
>>> cfg.item4
'test'
>>> cfg
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
"""
@staticmethod
def _validate_py_syntax(filename):
with open(filename, encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
content = f.read()
try:
ast.parse(content)
except SyntaxError as e:
raise SyntaxError('There are syntax errors in config '
f'file {filename}: {e}')
@staticmethod
def _substitute_predefined_vars(filename, temp_config_name):
file_dirname = osp.dirname(filename)
file_basename = osp.basename(filename)
file_basename_no_extension = osp.splitext(file_basename)[0]
file_extname = osp.splitext(filename)[1]
support_templates = dict(
fileDirname=file_dirname,
fileBasename=file_basename,
fileBasenameNoExtension=file_basename_no_extension,
fileExtname=file_extname)
with open(filename, encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
for key, value in support_templates.items():
regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
value = value.replace('\\', '/')
config_file = re.sub(regexp, value, config_file)
with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
tmp_config_file.write(config_file)
@staticmethod
def _pre_substitute_base_vars(filename, temp_config_name):
"""Substitute base variable placehoders to string, so that parsing
would work."""
with open(filename, encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
base_var_dict = {}
regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
base_vars = set(re.findall(regexp, config_file))
for base_var in base_vars:
randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
base_var_dict[randstr] = base_var
regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
config_file = re.sub(regexp, f'"{randstr}"', config_file)
with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
tmp_config_file.write(config_file)
return base_var_dict
@staticmethod
def _substitute_base_vars(cfg, base_var_dict, base_cfg):
"""Substitute variable strings to their actual values."""
cfg = copy.deepcopy(cfg)
if isinstance(cfg, dict):
for k, v in cfg.items():
if isinstance(v, str) and v in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[v].split('.'):
new_v = new_v[new_k]
cfg[k] = new_v
elif isinstance(v, (list, tuple, dict)):
cfg[k] = Config._substitute_base_vars(
v, base_var_dict, base_cfg)
elif isinstance(cfg, tuple):
cfg = tuple(
Config._substitute_base_vars(c, base_var_dict, base_cfg)
for c in cfg)
elif isinstance(cfg, list):
cfg = [
Config._substitute_base_vars(c, base_var_dict, base_cfg)
for c in cfg
]
elif isinstance(cfg, str) and cfg in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[cfg].split('.'):
new_v = new_v[new_k]
cfg = new_v
return cfg
@staticmethod
def _file2dict(filename, use_predefined_variables=True):
filename = osp.abspath(osp.expanduser(filename))
check_file_exist(filename)
fileExtname = osp.splitext(filename)[1]
if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
raise OSError('Only py/yml/yaml/json type are supported now!')
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(
dir=temp_config_dir, suffix=fileExtname)
if platform.system() == 'Windows':
temp_config_file.close()
temp_config_name = osp.basename(temp_config_file.name)
# Substitute predefined variables
if use_predefined_variables:
Config._substitute_predefined_vars(filename,
temp_config_file.name)
else:
shutil.copyfile(filename, temp_config_file.name)
# Substitute base variables from placeholders to strings
base_var_dict = Config._pre_substitute_base_vars(
temp_config_file.name, temp_config_file.name)
if filename.endswith('.py'):
temp_module_name = osp.splitext(temp_config_name)[0]
sys.path.insert(0, temp_config_dir)
Config._validate_py_syntax(filename)
mod = import_module(temp_module_name)
sys.path.pop(0)
cfg_dict = {
name: value
for name, value in mod.__dict__.items()
if not name.startswith('__')
and not isinstance(value, types.ModuleType)
and not isinstance(value, types.FunctionType)
}
# delete imported module
del sys.modules[temp_module_name]
elif filename.endswith(('.yml', '.yaml', '.json')):
cfg_dict = mmengine.load(temp_config_file.name)
# close temp file
temp_config_file.close()
# check deprecation information
if DEPRECATION_KEY in cfg_dict:
deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
warning_msg = f'The config file {filename} will be deprecated ' \
'in the future.'
if 'expected' in deprecation_info:
warning_msg += f' Please use {deprecation_info["expected"]} ' \
'instead.'
if 'reference' in deprecation_info:
warning_msg += ' More information can be found at ' \
f'{deprecation_info["reference"]}'
warnings.warn(warning_msg, DeprecationWarning)
cfg_text = filename + '\n'
with open(filename, encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
cfg_text += f.read()
if BASE_KEY in cfg_dict:
cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop(BASE_KEY)
base_filename = base_filename if isinstance(
base_filename, list) else [base_filename]
cfg_dict_list = list()
cfg_text_list = list()
for f in base_filename:
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
cfg_dict_list.append(_cfg_dict)
cfg_text_list.append(_cfg_text)
base_cfg_dict = dict()
for c in cfg_dict_list:
duplicate_keys = base_cfg_dict.keys() & c.keys()
if len(duplicate_keys) > 0:
raise KeyError('Duplicate key is not allowed among bases. '
f'Duplicate keys: {duplicate_keys}')
base_cfg_dict.update(c)
# Substitute base variables from strings to their actual values
cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
base_cfg_dict)
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict
# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = '\n'.join(cfg_text_list)
return cfg_dict, cfg_text
@staticmethod
def _merge_a_into_b(a, b, allow_list_keys=False):
"""merge dict ``a`` into dict ``b`` (non-inplace).
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
in-place modifications.
Args:
a (dict): The source dict to be merged into ``b``.
b (dict): The origin dict to be fetch keys from ``a``.
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
are allowed in source ``a`` and will replace the element of the
corresponding index in b if b is a list. Default: False.
Returns:
dict: The modified dict of ``b`` using ``a``.
Examples:
# Normally merge a into b.
>>> Config._merge_a_into_b(
... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
{'obj': {'a': 2}}
# Delete b first and merge a into b.
>>> Config._merge_a_into_b(
... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
{'obj': {'a': 2}}
# b is a list
>>> Config._merge_a_into_b(
... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
[{'a': 2}, {'b': 2}]
"""
b = b.copy()
for k, v in a.items():
if allow_list_keys and k.isdigit() and isinstance(b, list):
k = int(k)
if len(b) <= k:
raise KeyError(f'Index {k} exceeds the length of list {b}')
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
elif isinstance(v, dict):
if k in b and not v.pop(DELETE_KEY, False):
allowed_types = (dict, list) if allow_list_keys else dict
if not isinstance(b[k], allowed_types):
raise TypeError(
f'{k}={v} in child config cannot inherit from '
f'base because {k} is a dict in the child config '
f'but is of type {type(b[k])} in base config. '
f'You may set `{DELETE_KEY}=True` to ignore the '
f'base config.')
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
else:
b[k] = ConfigDict(v)
else:
b[k] = v
return b
@staticmethod
def fromfile(filename,
use_predefined_variables=True,
import_custom_modules=True):
if isinstance(filename, Path):
filename = str(filename)
cfg_dict, cfg_text = Config._file2dict(filename,
use_predefined_variables)
if import_custom_modules and cfg_dict.get('custom_imports', None):
import_modules_from_strings(**cfg_dict['custom_imports'])
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
@staticmethod
def fromstring(cfg_str, file_format):
"""Generate config from config str.
Args:
cfg_str (str): Config str.
file_format (str): Config file format corresponding to the
config str. Only py/yml/yaml/json type are supported now!
Returns:
:obj:`Config`: Config obj.
"""
if file_format not in ['.py', '.json', '.yaml', '.yml']:
raise OSError('Only py/yml/yaml/json type are supported now!')
if file_format != '.py' and 'dict(' in cfg_str:
# check if users specify a wrong suffix for python
warnings.warn(
'Please check "file_format", the file format may be .py')
with tempfile.NamedTemporaryFile(
'w', encoding='utf-8', suffix=file_format,
delete=False) as temp_file:
temp_file.write(cfg_str)
# on windows, previous implementation cause error
# see PR 1077 for details
cfg = Config.fromfile(temp_file.name)
os.remove(temp_file.name)
return cfg
@staticmethod
def auto_argparser(description=None):
"""Generate argparser from config file automatically (experimental)"""
partial_parser = ArgumentParser(description=description)
partial_parser.add_argument('config', help='config file path')
cfg_file = partial_parser.parse_known_args()[0].config
cfg = Config.fromfile(cfg_file)
parser = ArgumentParser(description=description)
parser.add_argument('config', help='config file path')
add_args(parser, cfg)
return parser, cfg
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
if cfg_dict is None:
cfg_dict = dict()
elif not isinstance(cfg_dict, dict):
raise TypeError('cfg_dict must be a dict, but '
f'got {type(cfg_dict)}')
for key in cfg_dict:
if key in RESERVED_KEYS:
raise KeyError(f'{key} is reserved for config file')
if isinstance(filename, Path):
filename = str(filename)
super().__setattr__('_cfg_dict', ConfigDict(cfg_dict))
super().__setattr__('_filename', filename)
if cfg_text:
text = cfg_text
elif filename:
with open(filename) as f:
text = f.read()
else:
text = ''
super().__setattr__('_text', text)
@property
def filename(self):
return self._filename
@property
def text(self):
return self._text
@property
def pretty_text(self):
indent = 4
def _indent(s_, num_spaces):
s = s_.split('\n')
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s
def _format_basic_types(k, v, use_mapping=False):
if isinstance(v, str):
v_str = f"'{v}'"
else:
v_str = str(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: {v_str}'
else:
attr_str = f'{str(k)}={v_str}'
attr_str = _indent(attr_str, indent)
return attr_str
def _format_list(k, v, use_mapping=False):
# check if all items in the list are dict
if all(isinstance(_, dict) for _ in v):
v_str = '[\n'
v_str += '\n'.join(
f'dict({_indent(_format_dict(v_), indent)}),'
for v_ in v).rstrip(',')
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: {v_str}'
else:
attr_str = f'{str(k)}={v_str}'
attr_str = _indent(attr_str, indent) + ']'
else:
attr_str = _format_basic_types(k, v, use_mapping)
return attr_str
def _contain_invalid_identifier(dict_str):
contain_invalid_identifier = False
for key_name in dict_str:
contain_invalid_identifier |= \
(not str(key_name).isidentifier())
return contain_invalid_identifier
def _format_dict(input_dict, outest_level=False):
r = ''
s = []
use_mapping = _contain_invalid_identifier(input_dict)
if use_mapping:
r += '{'
for idx, (k, v) in enumerate(input_dict.items()):
is_last = idx >= len(input_dict) - 1
end = '' if outest_level or is_last else ','
if isinstance(v, dict):
v_str = '\n' + _format_dict(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: dict({v_str}'
else:
attr_str = f'{str(k)}=dict({v_str}'
attr_str = _indent(attr_str, indent) + ')' + end
elif isinstance(v, list):
attr_str = _format_list(k, v, use_mapping) + end
else:
attr_str = _format_basic_types(k, v, use_mapping) + end
s.append(attr_str)
r += '\n'.join(s)
if use_mapping:
r += '}'
return r
cfg_dict = self._cfg_dict.to_dict()
text = _format_dict(cfg_dict, outest_level=True)
# copied from setup.cfg
yapf_style = dict(
based_on_style='pep8',
blank_line_before_nested_class_or_def=True,
split_before_expression_after_opening_paren=True)
text, _ = FormatCode(text, style_config=yapf_style, verify=True)
return text
def __repr__(self):
return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
def __len__(self):
return len(self._cfg_dict)
def __getattr__(self, name):
return getattr(self._cfg_dict, name)
def __getitem__(self, name):
return self._cfg_dict.__getitem__(name)
def __setattr__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setattr__(name, value)
def __setitem__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setitem__(name, value)
def __iter__(self):
return iter(self._cfg_dict)
def __getstate__(self):
return (self._cfg_dict, self._filename, self._text)
def __copy__(self):
cls = self.__class__
other = cls.__new__(cls)
other.__dict__.update(self.__dict__)
return other
def __deepcopy__(self, memo):
cls = self.__class__
other = cls.__new__(cls)
memo[id(self)] = other
for key, value in self.__dict__.items():
super(Config, other).__setattr__(key, copy.deepcopy(value, memo))
return other
def __setstate__(self, state):
_cfg_dict, _filename, _text = state
super().__setattr__('_cfg_dict', _cfg_dict)
super().__setattr__('_filename', _filename)
super().__setattr__('_text', _text)
def dump(self, file=None):
"""Dumps config into a file or returns a string representation of the
config.
If a file argument is given, saves the config to that file using the
format defined by the file argument extension.
Otherwise, returns a string representing the config. The formatting of
this returned string is defined by the extension of `self.filename`. If
`self.filename` is not defined, returns a string representation of a
dict (lowercased and using ' for strings).
Examples:
>>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0),
... item3=True, item4='test')
>>> cfg = Config(cfg_dict=cfg_dict)
>>> dump_file = "a.py"
>>> cfg.dump(dump_file)
Args:
file (str, optional): Path of the output file where the config
will be dumped. Defaults to None.
"""
cfg_dict = super().__getattribute__('_cfg_dict').to_dict()
if file is None:
if self.filename is None or self.filename.endswith('.py'):
return self.pretty_text
else:
file_format = self.filename.split('.')[-1]
return mmengine.dump(cfg_dict, file_format=file_format)
elif file.endswith('.py'):
with open(file, 'w', encoding='utf-8') as f:
f.write(self.pretty_text)
else:
file_format = file.split('.')[-1]
return mmengine.dump(cfg_dict, file=file, file_format=file_format)
def merge_from_dict(self, options, allow_list_keys=True):
"""Merge list into cfg_dict.
Merge the dict parsed by MultipleKVAction into this cfg.
Examples:
>>> options = {'model.backbone.depth': 50,
... 'model.backbone.with_cp':True}
>>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
>>> cfg.merge_from_dict(options)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(
... model=dict(backbone=dict(depth=50, with_cp=True)))
>>> # Merge list element
>>> cfg = Config(dict(pipeline=[
... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
>>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
>>> cfg.merge_from_dict(options, allow_list_keys=True)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(pipeline=[
... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
Args:
options (dict): dict of configs to merge from.
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
are allowed in ``options`` and will replace the element of the
corresponding index in the config if the config is a list.
Default: True.
"""
option_cfg_dict = {}
for full_key, v in options.items():
d = option_cfg_dict
key_list = full_key.split('.')
for subkey in key_list[:-1]:
d.setdefault(subkey, ConfigDict())
d = d[subkey]
subkey = key_list[-1]
d[subkey] = v
cfg_dict = super().__getattribute__('_cfg_dict')
super().__setattr__(
'_cfg_dict',
Config._merge_a_into_b(
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options can
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
"""
@staticmethod
def _parse_int_float_bool(val):
try:
return int(val)
except ValueError:
pass
try:
return float(val)
except ValueError:
pass
if val.lower() in ['true', 'false']:
return True if val.lower() == 'true' else False
if val == 'None':
return None
return val
@staticmethod
def _parse_iterable(val):
"""Parse iterable values in the string.
All elements inside '()' or '[]' are treated as iterable values.
Args:
val (str): Value string.
Returns:
list | tuple: The expanded list or tuple from the string.
Examples:
>>> DictAction._parse_iterable('1,2,3')
[1, 2, 3]
>>> DictAction._parse_iterable('[a, b, c]')
['a', 'b', 'c']
>>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
[(1, 2, 3), ['a', 'b'], 'c']
"""
def find_next_comma(string):
"""Find the position of next comma in the string.
If no ',' is found in the string, return the string length. All
chars inside '()' and '[]' are treated as one element and thus ','
inside these brackets are ignored.
"""
assert (string.count('(') == string.count(')')) and (
string.count('[') == string.count(']')), \
f'Imbalanced brackets exist in {string}'
end = len(string)
for idx, char in enumerate(string):
pre = string[:idx]
# The string before this ',' is balanced
if ((char == ',') and (pre.count('(') == pre.count(')'))
and (pre.count('[') == pre.count(']'))):
end = idx
break
return end
# Strip ' and " characters and replace whitespace.
val = val.strip('\'\"').replace(' ', '')
is_tuple = False
if val.startswith('(') and val.endswith(')'):
is_tuple = True
val = val[1:-1]
elif val.startswith('[') and val.endswith(']'):
val = val[1:-1]
elif ',' not in val:
# val is a single value
return DictAction._parse_int_float_bool(val)
values = []
while len(val) > 0:
comma_idx = find_next_comma(val)
element = DictAction._parse_iterable(val[:comma_idx])
values.append(element)
val = val[comma_idx + 1:]
if is_tuple:
values = tuple(values)
return values
def __call__(self, parser, namespace, values, option_string=None):
options = {}
for kv in values:
key, val = kv.split('=', maxsplit=1)
options[key] = self._parse_iterable(val)
setattr(namespace, self.dest, options)
# Copyright (c) OpenMMLab. All rights reserved.
def is_ipu_available() -> bool:
try:
import poptorch
return poptorch.ipuHardwareIsAvailable()
except ImportError:
return False
IS_IPU_AVAILABLE = is_ipu_available()
def is_mlu_available() -> bool:
try:
import torch
return (hasattr(torch, 'is_mlu_available')
and torch.is_mlu_available())
except Exception:
return False
from mmengine.device import (is_cuda_available, is_mlu_available,
is_mps_available)
IS_MLU_AVAILABLE = is_mlu_available()
def is_mps_available() -> bool:
"""Return True if mps devices exist.
It's specialized for mac m1 chips and require torch version 1.12 or higher.
"""
try:
import torch
return hasattr(torch.backends,
'mps') and torch.backends.mps.is_available()
except Exception:
return False
IS_MPS_AVAILABLE = is_mps_available()
IS_CUDA_AVAILABLE = is_cuda_available()
# Copyright (c) OpenMMLab. All rights reserved.
"""This file holding some environment constant for sharing by other files."""
import os.path as osp
import subprocess
import sys
from collections import defaultdict
import cv2
import torch
from mmengine.utils import collect_env as mmengine_collect_env
import mmcv
from .parrots_wrapper import get_build_config
def collect_env():
......@@ -32,80 +25,12 @@ def collect_env():
``torch.__config__.show()``.
- TorchVision (optional): TorchVision version.
- OpenCV: OpenCV version.
- MMEngine: MMEngine version.
- MMCV: MMCV version.
- MMCV Compiler: The GCC version for compiling MMCV ops.
- MMCV CUDA Compiler: The CUDA version for compiling MMCV ops.
"""
env_info = {}
env_info['sys.platform'] = sys.platform
env_info['Python'] = sys.version.replace('\n', '')
cuda_available = torch.cuda.is_available()
env_info['CUDA available'] = cuda_available
if cuda_available:
devices = defaultdict(list)
for k in range(torch.cuda.device_count()):
devices[torch.cuda.get_device_name(k)].append(str(k))
for name, device_ids in devices.items():
env_info['GPU ' + ','.join(device_ids)] = name
from mmcv.utils.parrots_wrapper import _get_cuda_home
CUDA_HOME = _get_cuda_home()
env_info['CUDA_HOME'] = CUDA_HOME
if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
try:
nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True)
nvcc = nvcc.decode('utf-8').strip()
release = nvcc.rfind('Cuda compilation tools')
build = nvcc.rfind('Build ')
nvcc = nvcc[release:build].strip()
except subprocess.SubprocessError:
nvcc = 'Not Available'
env_info['NVCC'] = nvcc
try:
# Check C++ Compiler.
# For Unix-like, sysconfig has 'CC' variable like 'gcc -pthread ...',
# indicating the compiler used, we use this to get the compiler name
import sysconfig
cc = sysconfig.get_config_var('CC')
if cc:
cc = osp.basename(cc.split()[0])
cc_info = subprocess.check_output(f'{cc} --version', shell=True)
env_info['GCC'] = cc_info.decode('utf-8').partition(
'\n')[0].strip()
else:
# on Windows, cl.exe is not in PATH. We need to find the path.
# distutils.ccompiler.new_compiler() returns a msvccompiler
# object and after initialization, path to cl.exe is found.
import locale
import os
from distutils.ccompiler import new_compiler
ccompiler = new_compiler()
ccompiler.initialize()
cc = subprocess.check_output(
f'{ccompiler.cc}', stderr=subprocess.STDOUT, shell=True)
encoding = os.device_encoding(
sys.stdout.fileno()) or locale.getpreferredencoding()
env_info['MSVC'] = cc.decode(encoding).partition('\n')[0].strip()
env_info['GCC'] = 'n/a'
except subprocess.CalledProcessError:
env_info['GCC'] = 'n/a'
env_info['PyTorch'] = torch.__version__
env_info['PyTorch compiling details'] = get_build_config()
try:
import torchvision
env_info['TorchVision'] = torchvision.__version__
except ModuleNotFoundError:
pass
env_info['OpenCV'] = cv2.__version__
env_info = mmengine_collect_env()
env_info['MMCV'] = mmcv.__version__
try:
......
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import torch.distributed as dist
logger_initialized: dict = {}
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified and the process rank is 0, a FileHandler
will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
file_mode (str): The file mode used in opening log file.
Defaults to 'w'.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
# handle hierarchical names
# e.g., logger "a" is initialized, then logger "a.b" will skip the
# initialization since it is a child of "a".
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
# handle duplicate logs to the console
# Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
# to the root logger. As logger.propagate is True by default, this root
# level handler causes logging messages from rank>0 processes to
# unexpectedly show up on the console, creating much unwanted clutter.
# To fix this issue, we set the root logger's StreamHandler, if any, to log
# at the ERROR level.
for handler in logger.root.handlers:
if type(handler) is logging.StreamHandler:
handler.setLevel(logging.ERROR)
stream_handler = logging.StreamHandler()
handlers = [stream_handler]
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
# only rank 0 will add a FileHandler
if rank == 0 and log_file is not None:
# Here, the default behaviour of the official logger is 'a'. Thus, we
# provide an interface to change the file mode to the default
# behaviour.
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)
if rank == 0:
logger.setLevel(log_level)
else:
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
return logger
def print_log(msg, logger=None, level=logging.INFO):
"""Print a log message.
Args:
msg (str): The message to be logged.
logger (logging.Logger | str | None): The logger to be used.
Some special loggers are:
- "silent": no message will be printed.
- other str: the logger obtained with `get_root_logger(logger)`.
- None: The `print()` method will be used to print log messages.
level (int): Logging level. Only available when `logger` is a Logger
object or "root".
"""
if logger is None:
print(msg)
elif isinstance(logger, logging.Logger):
logger.log(level, msg)
elif logger == 'silent':
pass
elif isinstance(logger, str):
_logger = get_logger(logger)
_logger.log(level, msg)
else:
raise TypeError(
'logger should be either a logging.Logger object, str, '
f'"silent" or None, but got {type(logger)}')
# Copyright (c) OpenMMLab. All rights reserved.
import collections.abc
import functools
import itertools
import subprocess
import warnings
from collections import abc
from importlib import import_module
from inspect import getfullargspec
from itertools import repeat
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def is_str(x):
"""Whether the input is an string instance.
Note: This method is deprecated since python 2 is no longer supported.
"""
return isinstance(x, str)
def import_modules_from_strings(imports, allow_failed_imports=False):
"""Import modules from the given list of strings.
Args:
imports (list | str | None): The given module names to be imported.
allow_failed_imports (bool): If True, the failed imports will return
None. Otherwise, an ImportError is raise. Default: False.
Returns:
list[module] | module | None: The imported modules.
Examples:
>>> osp, sys = import_modules_from_strings(
... ['os.path', 'sys'])
>>> import os.path as osp_
>>> import sys as sys_
>>> assert osp == osp_
>>> assert sys == sys_
"""
if not imports:
return
single_import = False
if isinstance(imports, str):
single_import = True
imports = [imports]
if not isinstance(imports, list):
raise TypeError(
f'custom_imports must be a list but got type {type(imports)}')
imported = []
for imp in imports:
if not isinstance(imp, str):
raise TypeError(
f'{imp} is of type {type(imp)} and cannot be imported.')
try:
imported_tmp = import_module(imp)
except ImportError:
if allow_failed_imports:
warnings.warn(f'{imp} failed to import and is ignored.',
UserWarning)
imported_tmp = None
else:
raise ImportError
imported.append(imported_tmp)
if single_import:
imported = imported[0]
return imported
def iter_cast(inputs, dst_type, return_type=None):
"""Cast elements of an iterable object into some type.
Args:
inputs (Iterable): The input object.
dst_type (type): Destination type.
return_type (type, optional): If specified, the output object will be
converted to this type, otherwise an iterator.
Returns:
iterator or specified type: The converted object.
"""
if not isinstance(inputs, abc.Iterable):
raise TypeError('inputs must be an iterable object')
if not isinstance(dst_type, type):
raise TypeError('"dst_type" must be a valid type')
out_iterable = map(dst_type, inputs)
if return_type is None:
return out_iterable
else:
return return_type(out_iterable)
def list_cast(inputs, dst_type):
"""Cast elements of an iterable object into a list of some type.
A partial method of :func:`iter_cast`.
"""
return iter_cast(inputs, dst_type, return_type=list)
def tuple_cast(inputs, dst_type):
"""Cast elements of an iterable object into a tuple of some type.
A partial method of :func:`iter_cast`.
"""
return iter_cast(inputs, dst_type, return_type=tuple)
def is_seq_of(seq, expected_type, seq_type=None):
"""Check whether it is a sequence of some type.
Args:
seq (Sequence): The sequence to be checked.
expected_type (type): Expected type of sequence items.
seq_type (type, optional): Expected sequence type.
Returns:
bool: Whether the sequence is valid.
"""
if seq_type is None:
exp_seq_type = abc.Sequence
else:
assert isinstance(seq_type, type)
exp_seq_type = seq_type
if not isinstance(seq, exp_seq_type):
return False
for item in seq:
if not isinstance(item, expected_type):
return False
return True
def is_list_of(seq, expected_type):
"""Check whether it is a list of some type.
A partial method of :func:`is_seq_of`.
"""
return is_seq_of(seq, expected_type, seq_type=list)
def is_tuple_of(seq, expected_type):
"""Check whether it is a tuple of some type.
A partial method of :func:`is_seq_of`.
"""
return is_seq_of(seq, expected_type, seq_type=tuple)
def slice_list(in_list, lens):
"""Slice a list into several sub lists by a list of given length.
Args:
in_list (list): The list to be sliced.
lens(int or list): The expected length of each out list.
Returns:
list: A list of sliced list.
"""
if isinstance(lens, int):
assert len(in_list) % lens == 0
lens = [lens] * int(len(in_list) / lens)
if not isinstance(lens, list):
raise TypeError('"indices" must be an integer or a list of integers')
elif sum(lens) != len(in_list):
raise ValueError('sum of lens and list length does not '
f'match: {sum(lens)} != {len(in_list)}')
out_list = []
idx = 0
for i in range(len(lens)):
out_list.append(in_list[idx:idx + lens[i]])
idx += lens[i]
return out_list
def concat_list(in_list):
"""Concatenate a list of list into a single list.
Args:
in_list (list): The list of list to be merged.
Returns:
list: The concatenated flat list.
"""
return list(itertools.chain(*in_list))
def check_prerequisites(
prerequisites,
checker,
msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
'found, please install them first.'): # yapf: disable
"""A decorator factory to check if prerequisites are satisfied.
Args:
prerequisites (str of list[str]): Prerequisites to be checked.
checker (callable): The checker method that returns True if a
prerequisite is meet, False otherwise.
msg_tmpl (str): The message template with two variables.
Returns:
decorator: A specific decorator.
"""
def wrap(func):
@functools.wraps(func)
def wrapped_func(*args, **kwargs):
requirements = [prerequisites] if isinstance(
prerequisites, str) else prerequisites
missing = []
for item in requirements:
if not checker(item):
missing.append(item)
if missing:
print(msg_tmpl.format(', '.join(missing), func.__name__))
raise RuntimeError('Prerequisites not meet.')
else:
return func(*args, **kwargs)
return wrapped_func
return wrap
def _check_py_package(package):
try:
import_module(package)
except ImportError:
return False
else:
return True
def _check_executable(cmd):
if subprocess.call(f'which {cmd}', shell=True) != 0:
return False
else:
return True
def requires_package(prerequisites):
"""A decorator to check if some python packages are installed.
Example:
>>> @requires_package('numpy')
>>> func(arg1, args):
>>> return numpy.zeros(1)
array([0.])
>>> @requires_package(['numpy', 'non_package'])
>>> func(arg1, args):
>>> return numpy.zeros(1)
ImportError
"""
return check_prerequisites(prerequisites, checker=_check_py_package)
def requires_executable(prerequisites):
"""A decorator to check if some executable files are installed.
Example:
>>> @requires_executable('ffmpeg')
>>> func(arg1, args):
>>> print(1)
1
"""
return check_prerequisites(prerequisites, checker=_check_executable)
def deprecated_api_warning(name_dict, cls_name=None):
"""A decorator to check if some arguments are deprecate and try to replace
deprecate src_arg_name to dst_arg_name.
Args:
name_dict(dict):
key (str): Deprecate argument names.
val (str): Expected argument names.
Returns:
func: New function.
"""
def api_warning_wrapper(old_func):
@functools.wraps(old_func)
def new_func(*args, **kwargs):
# get the arg spec of the decorated method
args_info = getfullargspec(old_func)
# get name of the function
func_name = old_func.__name__
if cls_name is not None:
func_name = f'{cls_name}.{func_name}'
if args:
arg_names = args_info.args[:len(args)]
for src_arg_name, dst_arg_name in name_dict.items():
if src_arg_name in arg_names:
warnings.warn(
f'"{src_arg_name}" is deprecated in '
f'`{func_name}`, please use "{dst_arg_name}" '
'instead', DeprecationWarning)
arg_names[arg_names.index(src_arg_name)] = dst_arg_name
if kwargs:
for src_arg_name, dst_arg_name in name_dict.items():
if src_arg_name in kwargs:
assert dst_arg_name not in kwargs, (
f'The expected behavior is to replace '
f'the deprecated key `{src_arg_name}` to '
f'new key `{dst_arg_name}`, but got them '
f'in the arguments at the same time, which '
f'is confusing. `{src_arg_name} will be '
f'deprecated in the future, please '
f'use `{dst_arg_name}` instead.')
warnings.warn(
f'"{src_arg_name}" is deprecated in '
f'`{func_name}`, please use "{dst_arg_name}" '
'instead', DeprecationWarning)
kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
# apply converted arguments to the decorated method
output = old_func(*args, **kwargs)
return output
return new_func
return api_warning_wrapper
def is_method_overridden(method, base_class, derived_class):
"""Check if a method of base class is overridden in derived class.
Args:
method (str): the method name to check.
base_class (type): the class of the base class.
derived_class (type | Any): the class or instance of the derived class.
"""
assert isinstance(base_class, type), \
"base_class doesn't accept instance, Please pass class instead."
if not isinstance(derived_class, type):
derived_class = derived_class.__class__
base_method = getattr(base_class, method)
derived_method = getattr(derived_class, method)
return derived_method != base_method
def has_method(obj: object, method: str) -> bool:
"""Check whether the object has a method.
Args:
method (str): The method name to check.
obj (object): The object to check.
Returns:
bool: True if the object has the method else False.
"""
return hasattr(obj, method) and callable(getattr(obj, method))
# Copyright (c) OpenMMLab. All rights reserved.
import os
from .parrots_wrapper import TORCH_VERSION
from mmengine.utils.parrots_wrapper import TORCH_VERSION
parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')
......
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
import torch
TORCH_VERSION = torch.__version__
def is_cuda_available() -> bool:
return torch.cuda.is_available()
IS_CUDA_AVAILABLE = is_cuda_available()
def is_rocm_pytorch() -> bool:
is_rocm = False
if TORCH_VERSION != 'parrots':
try:
from torch.utils.cpp_extension import ROCM_HOME
is_rocm = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False
except ImportError:
pass
return is_rocm
def _get_cuda_home():
if TORCH_VERSION == 'parrots':
from parrots.utils.build_extension import CUDA_HOME
else:
if is_rocm_pytorch():
from torch.utils.cpp_extension import ROCM_HOME
CUDA_HOME = ROCM_HOME
else:
from torch.utils.cpp_extension import CUDA_HOME
return CUDA_HOME
def get_build_config():
if TORCH_VERSION == 'parrots':
from parrots.config import get_build_info
return get_build_info()
else:
return torch.__config__.show()
def _get_conv():
if TORCH_VERSION == 'parrots':
from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
else:
from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
return _ConvNd, _ConvTransposeMixin
def _get_dataloader():
if TORCH_VERSION == 'parrots':
from torch.utils.data import DataLoader, PoolDataLoader
else:
from torch.utils.data import DataLoader
PoolDataLoader = DataLoader
return DataLoader, PoolDataLoader
def _get_extension():
if TORCH_VERSION == 'parrots':
from parrots.utils.build_extension import BuildExtension, Extension
CppExtension = partial(Extension, cuda=False)
CUDAExtension = partial(Extension, cuda=True)
else:
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
CUDAExtension)
return BuildExtension, CppExtension, CUDAExtension
def _get_pool():
if TORCH_VERSION == 'parrots':
from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd,
_MaxPoolNd)
else:
from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd,
_MaxPoolNd)
return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
def _get_norm():
if TORCH_VERSION == 'parrots':
from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
else:
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
SyncBatchNorm_ = torch.nn.SyncBatchNorm
return _BatchNorm, _InstanceNorm, SyncBatchNorm_
_ConvNd, _ConvTransposeMixin = _get_conv()
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension()
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
class SyncBatchNorm(SyncBatchNorm_): # type: ignore
def _check_input_dim(self, input):
if TORCH_VERSION == 'parrots':
if input.dim() < 2:
raise ValueError(
f'expected at least 2D input (got {input.dim()}D input)')
else:
super()._check_input_dim(input)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment