Unverified Commit 5947178e authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

Remove many functions in utils and migrate them to mmengine (#2217)

* Remove runner, parallel, engine and device

* fix format

* remove outdated docs

* migrate many functions to mmengine

* remove sync_bn.py
parent 9185eee8
...@@ -6,11 +6,11 @@ import torch ...@@ -6,11 +6,11 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmengine import print_log from mmengine import print_log
from mmengine.registry import MODELS from mmengine.registry import MODELS
from mmengine.utils import deprecated_api_warning
from torch.autograd import Function from torch.autograd import Function
from torch.autograd.function import once_differentiable from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single from torch.nn.modules.utils import _pair, _single
from mmcv.utils import deprecated_api_warning
from ..utils import ext_loader from ..utils import ext_loader
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext(
......
...@@ -3,16 +3,16 @@ import math ...@@ -3,16 +3,16 @@ import math
import warnings import warnings
from typing import Optional, no_type_check from typing import Optional, no_type_check
import mmengine
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmengine.model import BaseModule from mmengine.model import BaseModule
from mmengine.model.utils import constant_init, xavier_init from mmengine.model.utils import constant_init, xavier_init
from mmengine.registry import MODELS from mmengine.registry import MODELS
from mmengine.utils import deprecated_api_warning
from torch.autograd.function import Function, once_differentiable from torch.autograd.function import Function, once_differentiable
import mmcv
from mmcv import deprecated_api_warning
from ..utils import ext_loader from ..utils import ext_loader
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext(
...@@ -193,7 +193,7 @@ class MultiScaleDeformableAttention(BaseModule): ...@@ -193,7 +193,7 @@ class MultiScaleDeformableAttention(BaseModule):
dropout: float = 0.1, dropout: float = 0.1,
batch_first: bool = False, batch_first: bool = False,
norm_cfg: Optional[dict] = None, norm_cfg: Optional[dict] = None,
init_cfg: Optional[mmcv.ConfigDict] = None): init_cfg: Optional[mmengine.ConfigDict] = None):
super().__init__(init_cfg) super().__init__(init_cfg)
if embed_dims % num_heads != 0: if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, ' raise ValueError(f'embed_dims must be divisible by num_heads, '
......
...@@ -3,9 +3,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -3,9 +3,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from mmengine.utils import deprecated_api_warning
from torch import Tensor from torch import Tensor
from mmcv.utils import deprecated_api_warning
from ..utils import ext_loader from ..utils import ext_loader
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext(
......
...@@ -3,9 +3,10 @@ from typing import Any, Optional, Tuple, Union ...@@ -3,9 +3,10 @@ from typing import Any, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmengine.utils import is_tuple_of
from torch.autograd import Function from torch.autograd import Function
from ..utils import ext_loader, is_tuple_of from ..utils import ext_loader
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext(
'_ext', ['riroi_align_rotated_forward', 'riroi_align_rotated_backward']) '_ext', ['riroi_align_rotated_forward', 'riroi_align_rotated_backward'])
......
...@@ -3,11 +3,12 @@ from typing import Any ...@@ -3,11 +3,12 @@ from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmengine.utils import deprecated_api_warning
from torch.autograd import Function from torch.autograd import Function
from torch.autograd.function import once_differentiable from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from ..utils import deprecated_api_warning, ext_loader from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ext_module = ext_loader.load_ext('_ext',
['roi_align_forward', 'roi_align_backward']) ['roi_align_forward', 'roi_align_backward'])
......
...@@ -3,10 +3,11 @@ from typing import Any, Optional, Tuple, Union ...@@ -3,10 +3,11 @@ from typing import Any, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmengine.utils import deprecated_api_warning
from torch.autograd import Function from torch.autograd import Function
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from ..utils import deprecated_api_warning, ext_loader from ..utils import ext_loader
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext(
'_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward']) '_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward'])
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Tuple, Union from typing import Any, Tuple, Union
import mmengine
import torch import torch
from torch import nn as nn from torch import nn as nn
from torch.autograd import Function from torch.autograd import Function
import mmcv
from ..utils import ext_loader from ..utils import ext_loader
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext(
...@@ -86,7 +86,7 @@ class RoIAwarePool3dFunction(Function): ...@@ -86,7 +86,7 @@ class RoIAwarePool3dFunction(Function):
out_x = out_y = out_z = out_size out_x = out_y = out_z = out_size
else: else:
assert len(out_size) == 3 assert len(out_size) == 3
assert mmcv.is_tuple_of(out_size, int) assert mmengine.is_tuple_of(out_size, int)
out_x, out_y, out_z = out_size out_x, out_y, out_z = out_size
num_rois = rois.shape[0] num_rois = rois.shape[0]
......
...@@ -4,10 +4,10 @@ import torch.nn as nn ...@@ -4,10 +4,10 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmengine.model.utils import constant_init from mmengine.model.utils import constant_init
from mmengine.registry import MODELS from mmengine.registry import MODELS
from mmengine.utils import TORCH_VERSION, digit_version
from mmcv.cnn import ConvAWS2d from mmcv.cnn import ConvAWS2d
from mmcv.ops.deform_conv import deform_conv2d from mmcv.ops.deform_conv import deform_conv2d
from mmcv.utils import TORCH_VERSION, digit_version
@MODELS.register_module(name='SAC') @MODELS.register_module(name='SAC')
......
...@@ -98,10 +98,10 @@ ...@@ -98,10 +98,10 @@
from typing import Any, List, Tuple, Union from typing import Any, List, Tuple, Union
import torch import torch
from mmengine.utils import to_2tuple
from torch.autograd import Function from torch.autograd import Function
from torch.nn import functional as F from torch.nn import functional as F
from mmcv.utils import to_2tuple
from ..utils import ext_loader from ..utils import ext_loader
upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d']) upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union from typing import Sequence, Union
import mmengine
import numpy as np import numpy as np
import torch import torch
import mmcv
from .base import BaseTransform from .base import BaseTransform
from .builder import TRANSFORMS from .builder import TRANSFORMS
...@@ -29,7 +29,7 @@ def to_tensor( ...@@ -29,7 +29,7 @@ def to_tensor(
return data return data
elif isinstance(data, np.ndarray): elif isinstance(data, np.ndarray):
return torch.from_numpy(data) return torch.from_numpy(data)
elif isinstance(data, Sequence) and not mmcv.is_str(data): elif isinstance(data, Sequence) and not mmengine.is_str(data):
return torch.tensor(data) return torch.tensor(data)
elif isinstance(data, int): elif isinstance(data, int):
return torch.LongTensor([data]) return torch.LongTensor([data])
......
...@@ -3,6 +3,7 @@ import random ...@@ -3,6 +3,7 @@ import random
import warnings import warnings
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import mmengine
import numpy as np import numpy as np
import mmcv import mmcv
...@@ -797,7 +798,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -797,7 +798,7 @@ class MultiScaleFlipAug(BaseTransform):
if scales is not None: if scales is not None:
self.scales = scales if isinstance(scales, list) else [scales] self.scales = scales if isinstance(scales, list) else [scales]
self.scale_key = 'scale' self.scale_key = 'scale'
assert mmcv.is_list_of(self.scales, tuple) assert mmengine.is_list_of(self.scales, tuple)
else: else:
# if ``scales`` and ``scale_factor`` both be ``None`` # if ``scales`` and ``scale_factor`` both be ``None``
if scale_factor is None: if scale_factor is None:
...@@ -812,7 +813,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -812,7 +813,7 @@ class MultiScaleFlipAug(BaseTransform):
self.allow_flip = allow_flip self.allow_flip = allow_flip
self.flip_direction = flip_direction if isinstance( self.flip_direction = flip_direction if isinstance(
flip_direction, list) else [flip_direction] flip_direction, list) else [flip_direction]
assert mmcv.is_list_of(self.flip_direction, str) assert mmengine.is_list_of(self.flip_direction, str)
if not self.allow_flip and self.flip_direction != ['horizontal']: if not self.allow_flip and self.flip_direction != ['horizontal']:
warnings.warn( warnings.warn(
'flip_direction has no effect when flip is set to False') 'flip_direction has no effect when flip is set to False')
...@@ -934,7 +935,7 @@ class RandomChoiceResize(BaseTransform): ...@@ -934,7 +935,7 @@ class RandomChoiceResize(BaseTransform):
self.scales = scales self.scales = scales
else: else:
self.scales = [scales] self.scales = [scales]
assert mmcv.is_list_of(self.scales, tuple) assert mmengine.is_list_of(self.scales, tuple)
self.resize_cfg = dict(type=resize_type, **resize_kwargs) self.resize_cfg = dict(type=resize_type, **resize_kwargs)
# create a empty Resize object # create a empty Resize object
...@@ -950,7 +951,7 @@ class RandomChoiceResize(BaseTransform): ...@@ -950,7 +951,7 @@ class RandomChoiceResize(BaseTransform):
``scale_idx`` is the selected index in the given candidates. ``scale_idx`` is the selected index in the given candidates.
""" """
assert mmcv.is_list_of(self.scales, tuple) assert mmengine.is_list_of(self.scales, tuple)
scale_idx = np.random.randint(len(self.scales)) scale_idx = np.random.randint(len(self.scales))
scale = self.scales[scale_idx] scale = self.scales[scale_idx]
return scale, scale_idx return scale, scale_idx
...@@ -1033,7 +1034,7 @@ class RandomFlip(BaseTransform): ...@@ -1033,7 +1034,7 @@ class RandomFlip(BaseTransform):
direction: Union[str, direction: Union[str,
Sequence[Optional[str]]] = 'horizontal') -> None: Sequence[Optional[str]]] = 'horizontal') -> None:
if isinstance(prob, list): if isinstance(prob, list):
assert mmcv.is_list_of(prob, float) assert mmengine.is_list_of(prob, float)
assert 0 <= sum(prob) <= 1 assert 0 <= sum(prob) <= 1
elif isinstance(prob, float): elif isinstance(prob, float):
assert 0 <= prob <= 1 assert 0 <= prob <= 1
...@@ -1046,7 +1047,7 @@ class RandomFlip(BaseTransform): ...@@ -1046,7 +1047,7 @@ class RandomFlip(BaseTransform):
if isinstance(direction, str): if isinstance(direction, str):
assert direction in valid_directions assert direction in valid_directions
elif isinstance(direction, list): elif isinstance(direction, list):
assert mmcv.is_list_of(direction, str) assert mmengine.is_list_of(direction, str)
assert set(direction).issubset(set(valid_directions)) assert set(direction).issubset(set(valid_directions))
else: else:
raise ValueError(f'direction must be either str or list of str, \ raise ValueError(f'direction must be either str or list of str, \
...@@ -1308,7 +1309,7 @@ class RandomResize(BaseTransform): ...@@ -1308,7 +1309,7 @@ class RandomResize(BaseTransform):
tuple: The targeted scale of the image to be resized. tuple: The targeted scale of the image to be resized.
""" """
assert mmcv.is_list_of(scales, tuple) and len(scales) == 2 assert mmengine.is_list_of(scales, tuple) and len(scales) == 2
scale_0 = [scales[0][0], scales[1][0]] scale_0 = [scales[0][0], scales[1][0]]
scale_1 = [scales[0][1], scales[1][1]] scale_1 = [scales[0][1], scales[1][1]]
edge_0 = np.random.randint(min(scale_0), max(scale_0) + 1) edge_0 = np.random.randint(min(scale_0), max(scale_0) + 1)
...@@ -1350,12 +1351,12 @@ class RandomResize(BaseTransform): ...@@ -1350,12 +1351,12 @@ class RandomResize(BaseTransform):
tuple: The targeted scale of the image to be resized. tuple: The targeted scale of the image to be resized.
""" """
if mmcv.is_tuple_of(self.scale, int): if mmengine.is_tuple_of(self.scale, int):
assert self.ratio_range is not None and len(self.ratio_range) == 2 assert self.ratio_range is not None and len(self.ratio_range) == 2
scale = self._random_sample_ratio( scale = self._random_sample_ratio(
self.scale, # type: ignore self.scale, # type: ignore
self.ratio_range) self.ratio_range)
elif mmcv.is_seq_of(self.scale, tuple): elif mmengine.is_seq_of(self.scale, tuple):
scale = self._random_sample(self.scale) # type: ignore scale = self._random_sample(self.scale) # type: ignore
else: else:
raise NotImplementedError('Do not support sampling function ' raise NotImplementedError('Do not support sampling function '
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import mmengine
import numpy as np import numpy as np
import mmcv
from .base import BaseTransform from .base import BaseTransform
from .builder import TRANSFORMS from .builder import TRANSFORMS
from .utils import cache_random_params, cache_randomness from .utils import cache_random_params, cache_randomness
...@@ -569,7 +569,7 @@ class RandomChoice(BaseTransform): ...@@ -569,7 +569,7 @@ class RandomChoice(BaseTransform):
super().__init__() super().__init__()
if prob is not None: if prob is not None:
assert mmcv.is_seq_of(prob, float) assert mmengine.is_seq_of(prob, float)
assert len(transforms) == len(prob), \ assert len(transforms) == len(prob), \
'``transforms`` and ``prob`` must have same lengths. ' \ '``transforms`` and ``prob`` must have same lengths. ' \
f'Got {len(transforms)} vs {len(prob)}.' f'Got {len(transforms)} vs {len(prob)}.'
......
# flake8: noqa
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .config import Config, ConfigDict, DictAction from .device_type import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
from .misc import (check_prerequisites, concat_list, deprecated_api_warning, from .env import collect_env
has_method, import_modules_from_strings, is_list_of, from .parrots_jit import jit, skip_no_elena
is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package,
slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
to_ntuple, tuple_cast)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
track_parallel_progress, track_progress)
from .testing import (assert_attrs_equal, assert_dict_contains_subset,
assert_dict_has_keys, assert_is_norm_layer,
assert_keys_equal, assert_params_all_zeros,
check_python_script)
from .timer import Timer, TimerError, check_time
from .version_utils import digit_version, get_git_hash
try: __all__ = [
import torch 'IS_MLU_AVAILABLE', 'IS_MPS_AVAILABLE', 'IS_CUDA_AVAILABLE', 'collect_env',
except ImportError: 'jit', 'skip_no_elena'
__all__ = [ ]
'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast',
'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of',
'slice_list', 'concat_list', 'check_prerequisites', 'requires_package',
'requires_executable', 'is_filepath', 'fopen', 'check_file_exist',
'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
'track_progress', 'track_iter_progress', 'track_parallel_progress',
'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
'digit_version', 'get_git_hash', 'import_modules_from_strings',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
'is_method_overridden', 'has_method'
]
else:
from .device_type import (IS_IPU_AVAILABLE, IS_MLU_AVAILABLE,
IS_MPS_AVAILABLE)
from .env import collect_env
from .logging import get_logger, print_log
from .parrots_jit import jit, skip_no_elena
# yapf: disable
from .parrots_wrapper import (IS_CUDA_AVAILABLE, TORCH_VERSION,
BuildExtension, CppExtension, CUDAExtension,
DataLoader, PoolDataLoader, SyncBatchNorm,
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
_AvgPoolNd, _BatchNorm, _ConvNd,
_ConvTransposeMixin, _get_cuda_home,
_InstanceNorm, _MaxPoolNd, get_build_config,
is_rocm_pytorch)
# yapf: enable
from .registry import Registry, build_from_cfg
from .seed import worker_init_fn
from .torch_ops import torch_meshgrid
from .trace import is_jit_tracing
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
'check_prerequisites', 'requires_package', 'requires_executable',
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist',
'symlink', 'scandir', 'ProgressBar', 'track_progress',
'track_iter_progress', 'track_parallel_progress', 'Registry',
'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm',
'_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm',
'_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd',
'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension',
'DataLoader', 'PoolDataLoader', 'TORCH_VERSION',
'deprecated_api_warning', 'digit_version', 'get_git_hash',
'import_modules_from_strings', 'jit', 'skip_no_elena',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
'_get_cuda_home', 'has_method', 'IS_CUDA_AVAILABLE', 'worker_init_fn',
'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE', 'IS_MPS_AVAILABLE',
'torch_meshgrid'
]
This diff is collapsed.
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmengine.device import (is_cuda_available, is_mlu_available,
is_mps_available)
def is_ipu_available() -> bool:
    """Return True when the ``poptorch`` runtime reports reachable IPU hardware.

    Returns False when ``poptorch`` is not installed at all.
    """
    try:
        import poptorch
    except ImportError:
        return False
    return poptorch.ipuHardwareIsAvailable()


# Evaluated once at import time; other modules read this constant.
IS_IPU_AVAILABLE = is_ipu_available()
def is_mlu_available() -> bool:
    """Return True when the installed torch exposes ``is_mlu_available`` and it
    reports an MLU device; any failure (missing torch, probe error) yields False.
    """
    try:
        import torch
        probe = getattr(torch, 'is_mlu_available', None)
        return probe is not None and probe()
    except Exception:
        # Broad on purpose: a broken device probe must not crash import.
        return False
IS_MLU_AVAILABLE = is_mlu_available() IS_MLU_AVAILABLE = is_mlu_available()
def is_mps_available() -> bool:
    """Return True if mps devices exist.

    It's specialized for mac m1 chips and require torch version 1.12 or higher.
    """
    try:
        import torch
        backend = getattr(torch.backends, 'mps', None)
        return backend is not None and backend.is_available()
    except Exception:
        # Broad on purpose: a broken backend probe must not crash import.
        return False
IS_MPS_AVAILABLE = is_mps_available() IS_MPS_AVAILABLE = is_mps_available()
IS_CUDA_AVAILABLE = is_cuda_available()
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
"""This file holding some environment constant for sharing by other files.""" """This file holding some environment constant for sharing by other files."""
import os.path as osp from mmengine.utils import collect_env as mmengine_collect_env
import subprocess
import sys
from collections import defaultdict
import cv2
import torch
import mmcv import mmcv
from .parrots_wrapper import get_build_config
def collect_env(): def collect_env():
...@@ -32,80 +25,12 @@ def collect_env(): ...@@ -32,80 +25,12 @@ def collect_env():
``torch.__config__.show()``. ``torch.__config__.show()``.
- TorchVision (optional): TorchVision version. - TorchVision (optional): TorchVision version.
- OpenCV: OpenCV version. - OpenCV: OpenCV version.
- MMEngine: MMEngine version.
- MMCV: MMCV version. - MMCV: MMCV version.
- MMCV Compiler: The GCC version for compiling MMCV ops. - MMCV Compiler: The GCC version for compiling MMCV ops.
- MMCV CUDA Compiler: The CUDA version for compiling MMCV ops. - MMCV CUDA Compiler: The CUDA version for compiling MMCV ops.
""" """
env_info = {} env_info = mmengine_collect_env()
env_info['sys.platform'] = sys.platform
env_info['Python'] = sys.version.replace('\n', '')
cuda_available = torch.cuda.is_available()
env_info['CUDA available'] = cuda_available
if cuda_available:
devices = defaultdict(list)
for k in range(torch.cuda.device_count()):
devices[torch.cuda.get_device_name(k)].append(str(k))
for name, device_ids in devices.items():
env_info['GPU ' + ','.join(device_ids)] = name
from mmcv.utils.parrots_wrapper import _get_cuda_home
CUDA_HOME = _get_cuda_home()
env_info['CUDA_HOME'] = CUDA_HOME
if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
try:
nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True)
nvcc = nvcc.decode('utf-8').strip()
release = nvcc.rfind('Cuda compilation tools')
build = nvcc.rfind('Build ')
nvcc = nvcc[release:build].strip()
except subprocess.SubprocessError:
nvcc = 'Not Available'
env_info['NVCC'] = nvcc
try:
# Check C++ Compiler.
# For Unix-like, sysconfig has 'CC' variable like 'gcc -pthread ...',
# indicating the compiler used, we use this to get the compiler name
import sysconfig
cc = sysconfig.get_config_var('CC')
if cc:
cc = osp.basename(cc.split()[0])
cc_info = subprocess.check_output(f'{cc} --version', shell=True)
env_info['GCC'] = cc_info.decode('utf-8').partition(
'\n')[0].strip()
else:
# on Windows, cl.exe is not in PATH. We need to find the path.
# distutils.ccompiler.new_compiler() returns a msvccompiler
# object and after initialization, path to cl.exe is found.
import locale
import os
from distutils.ccompiler import new_compiler
ccompiler = new_compiler()
ccompiler.initialize()
cc = subprocess.check_output(
f'{ccompiler.cc}', stderr=subprocess.STDOUT, shell=True)
encoding = os.device_encoding(
sys.stdout.fileno()) or locale.getpreferredencoding()
env_info['MSVC'] = cc.decode(encoding).partition('\n')[0].strip()
env_info['GCC'] = 'n/a'
except subprocess.CalledProcessError:
env_info['GCC'] = 'n/a'
env_info['PyTorch'] = torch.__version__
env_info['PyTorch compiling details'] = get_build_config()
try:
import torchvision
env_info['TorchVision'] = torchvision.__version__
except ModuleNotFoundError:
pass
env_info['OpenCV'] = cv2.__version__
env_info['MMCV'] = mmcv.__version__ env_info['MMCV'] = mmcv.__version__
try: try:
......
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import torch.distributed as dist
# Names of loggers already configured by ``get_logger``; prevents attaching
# duplicate handlers when the same (or a child) logger name is requested again.
logger_initialized: dict = {}
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'w'.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    # handle duplicate logs to the console
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
    # to the root logger. As logger.propagate is True by default, this root
    # level handler causes logging messages from rank>0 processes to
    # unexpectedly show up on the console, creating much unwanted clutter.
    # To fix this issue, we set the root logger's StreamHandler, if any, to log
    # at the ERROR level.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    # Rank decides which process owns the log file; ``dist`` is
    # torch.distributed imported at module level.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # only rank 0 will add a FileHandler
    if rank == 0 and log_file is not None:
        # Here, the default behaviour of the official logger is 'a'. Thus, we
        # provide an interface to change the file mode to the default
        # behaviour.
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    # Non-zero ranks only surface errors so multi-process runs stay quiet.
    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    # Remember that this name has been configured (see module-level dict).
    logger_initialized[name] = True

    return logger
def print_log(msg, logger=None, level=logging.INFO):
    """Print a log message.

    Args:
        msg (str): The message to be logged.
        logger (logging.Logger | str | None): The logger to be used.
            Some special loggers are:

            - "silent": no message will be printed.
            - other str: the logger obtained with `get_root_logger(logger)`.
            - None: The `print()` method will be used to print log messages.
        level (int): Logging level. Only available when `logger` is a Logger
            object or "root".
    """
    # Dispatch with guard clauses; order matters ("silent" beats generic str).
    if logger is None:
        print(msg)
        return
    if isinstance(logger, logging.Logger):
        logger.log(level, msg)
        return
    if logger == 'silent':
        return
    if isinstance(logger, str):
        get_logger(logger).log(level, msg)
        return
    raise TypeError(
        'logger should be either a logging.Logger object, str, '
        f'"silent" or None, but got {type(logger)}')
# Copyright (c) OpenMMLab. All rights reserved.
import collections.abc
import functools
import itertools
import subprocess
import warnings
from collections import abc
from importlib import import_module
from inspect import getfullargspec
from itertools import repeat
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
# Ready-made parsers for the common arities (mirrors torch.nn.modules.utils).
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
# Generic factory: ``to_ntuple(n)`` builds a parser for arbitrary n.
to_ntuple = _ntuple
def is_str(x):
    """Whether the input is an string instance.

    Note: This method is deprecated since python 2 is no longer supported.
    """
    # Kept only for backward compatibility with old call sites.
    return isinstance(x, str)
def import_modules_from_strings(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None. Otherwise, an ImportError is raise. Default: False.

    Returns:
        list[module] | module | None: The imported modules.

    Examples:
        >>> osp, sys = import_modules_from_strings(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        # Normalize a single name to a list; unwrap again before returning.
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if allow_failed_imports:
                warnings.warn(f'{imp} failed to import and is ignored.',
                              UserWarning)
                imported_tmp = None
            else:
                # Fix: the original bare ``raise ImportError`` discarded the
                # name of the failing module, making failures hard to debug.
                raise ImportError(f'Failed to import {imp}')
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported
def iter_cast(inputs, dst_type, return_type=None):
    """Cast elements of an iterable object into some type.

    Args:
        inputs (Iterable): The input object.
        dst_type (type): Destination type.
        return_type (type, optional): If specified, the output object will be
            converted to this type, otherwise an iterator.

    Returns:
        iterator or specified type: The converted object.
    """
    # Validate eagerly so errors surface at the call site, not on iteration.
    if not isinstance(inputs, abc.Iterable):
        raise TypeError('inputs must be an iterable object')
    if not isinstance(dst_type, type):
        raise TypeError('"dst_type" must be a valid type')
    converted = map(dst_type, inputs)
    if return_type is not None:
        return return_type(converted)
    return converted
def list_cast(inputs, dst_type):
    """Cast every element of ``inputs`` to ``dst_type`` and collect the
    results in a list (thin wrapper over :func:`iter_cast`)."""
    return iter_cast(inputs, dst_type, return_type=list)
def tuple_cast(inputs, dst_type):
    """Cast every element of ``inputs`` to ``dst_type`` and collect the
    results in a tuple (thin wrapper over :func:`iter_cast`)."""
    return iter_cast(inputs, dst_type, return_type=tuple)
def is_seq_of(seq, expected_type, seq_type=None):
    """Check whether it is a sequence of some type.

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type): Expected type of sequence items.
        seq_type (type, optional): Expected sequence type.

    Returns:
        bool: Whether the sequence is valid.
    """
    if seq_type is None:
        container_type = abc.Sequence
    else:
        assert isinstance(seq_type, type)
        container_type = seq_type
    if not isinstance(seq, container_type):
        return False
    # all() short-circuits on the first mismatching item, like the old loop.
    return all(isinstance(item, expected_type) for item in seq)
def is_list_of(seq, expected_type):
    """Check whether ``seq`` is a ``list`` whose items are all of
    ``expected_type`` (partial of :func:`is_seq_of`)."""
    return is_seq_of(seq, expected_type, seq_type=list)
def is_tuple_of(seq, expected_type):
    """Check whether ``seq`` is a ``tuple`` whose items are all of
    ``expected_type`` (partial of :func:`is_seq_of`)."""
    return is_seq_of(seq, expected_type, seq_type=tuple)
def slice_list(in_list, lens):
    """Slice a list into several sub lists by a list of given length.

    Args:
        in_list (list): The list to be sliced.
        lens(int or list): The expected length of each out list.

    Returns:
        list: A list of sliced list.
    """
    if isinstance(lens, int):
        # A single int means equal-sized chunks; the length must divide evenly.
        assert len(in_list) % lens == 0
        lens = [lens] * int(len(in_list) / lens)
    if not isinstance(lens, list):
        raise TypeError('"indices" must be an integer or a list of integers')
    elif sum(lens) != len(in_list):
        raise ValueError('sum of lens and list length does not '
                         f'match: {sum(lens)} != {len(in_list)}')
    out_list = []
    start = 0
    for length in lens:
        out_list.append(in_list[start:start + length])
        start += length
    return out_list
def concat_list(in_list):
    """Concatenate a list of list into a single list.

    Args:
        in_list (list): The list of list to be merged.

    Returns:
        list: The concatenated flat list.
    """
    # from_iterable avoids unpacking the outer list as call arguments.
    return list(itertools.chain.from_iterable(in_list))
def check_prerequisites(
        prerequisites,
        checker,
        msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
        'found, please install them first.'):  # yapf: disable
    """A decorator factory to check if prerequisites are satisfied.

    Args:
        prerequisites (str of list[str]): Prerequisites to be checked.
        checker (callable): The checker method that returns True if a
            prerequisite is meet, False otherwise.
        msg_tmpl (str): The message template with two variables.

    Returns:
        decorator: A specific decorator.
    """

    def wrap(func):

        @functools.wraps(func)
        def wrapped_func(*args, **kwargs):
            # Normalize a single prerequisite name into a list.
            requirements = [prerequisites] if isinstance(
                prerequisites, str) else prerequisites
            missing = [item for item in requirements if not checker(item)]
            if missing:
                print(msg_tmpl.format(', '.join(missing), func.__name__))
                raise RuntimeError('Prerequisites not meet.')
            return func(*args, **kwargs)

        return wrapped_func

    return wrap
def _check_py_package(package):
try:
import_module(package)
except ImportError:
return False
else:
return True
def _check_executable(cmd):
if subprocess.call(f'which {cmd}', shell=True) != 0:
return False
else:
return True
def requires_package(prerequisites):
    """A decorator to check if some python packages are installed.

    Example:
        >>> @requires_package('numpy')
        >>> func(arg1, args):
        >>>     return numpy.zeros(1)
        array([0.])
        >>> @requires_package(['numpy', 'non_package'])
        >>> func(arg1, args):
        >>>     return numpy.zeros(1)
        ImportError
    """
    return check_prerequisites(prerequisites, _check_py_package)
def requires_executable(prerequisites):
    """A decorator to check if some executable files are installed.

    Example:
        >>> @requires_executable('ffmpeg')
        >>> func(arg1, args):
        >>>     print(1)
        1
    """
    return check_prerequisites(prerequisites, _check_executable)
def deprecated_api_warning(name_dict, cls_name=None):
    """A decorator to check if some arguments are deprecate and try to replace
    deprecate src_arg_name to dst_arg_name.

    Args:
        name_dict (dict): Mapping from deprecated argument names (keys) to
            their expected replacement names (values).
        cls_name (str, optional): If given, it is prepended to the function
            name in warning messages (``"Cls.method"``). Defaults to None.

    Returns:
        func: New function that warns on deprecated names and forwards
        keyword arguments under the new names.
    """

    def api_warning_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get name of the function
            func_name = old_func.__name__
            if cls_name is not None:
                func_name = f'{cls_name}.{func_name}'
            if args:
                # Positional arguments are matched by position, so they need
                # no remapping; only warn when a deprecated name occupies one
                # of the filled positional slots.
                arg_names = args_info.args[:len(args)]
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in arg_names:
                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead', DeprecationWarning)
                        # Overwrite the local copy so the same slot is not
                        # matched again by another deprecated name.
                        arg_names[arg_names.index(src_arg_name)] = dst_arg_name
            if kwargs:
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in kwargs:
                        # Supplying both the old and the new name is ambiguous
                        # and therefore rejected outright.
                        assert dst_arg_name not in kwargs, (
                            f'The expected behavior is to replace '
                            f'the deprecated key `{src_arg_name}` to '
                            f'new key `{dst_arg_name}`, but got them '
                            f'in the arguments at the same time, which '
                            f'is confusing. `{src_arg_name} will be '
                            f'deprecated in the future, please '
                            f'use `{dst_arg_name}` instead.')
                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead', DeprecationWarning)
                        # Rename the keyword in place before forwarding.
                        kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
            # apply converted arguments to the decorated method
            output = old_func(*args, **kwargs)
            return output

        return new_func

    return api_warning_wrapper
def is_method_overridden(method, base_class, derived_class):
    """Check if a method of base class is overridden in derived class.

    Args:
        method (str): The method name to check.
        base_class (type): The class of the base class.
        derived_class (type | Any): The class or an instance of the derived
            class.

    Returns:
        bool: True if ``derived_class`` provides its own implementation of
        ``method``.
    """
    assert isinstance(base_class, type), \
        "base_class doesn't accept instance, Please pass class instead."
    cls = derived_class if isinstance(derived_class,
                                      type) else derived_class.__class__
    return getattr(cls, method) != getattr(base_class, method)
def has_method(obj: object, method: str) -> bool:
    """Check whether the object has a callable attribute with the given name.

    Args:
        obj (object): The object to check.
        method (str): The method name to check.

    Returns:
        bool: True if the object has the method else False.
    """
    candidate = getattr(obj, method, None)
    return callable(candidate)
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os import os
from .parrots_wrapper import TORCH_VERSION from mmengine.utils.parrots_wrapper import TORCH_VERSION
parrots_jit_option = os.getenv('PARROTS_JIT_OPTION') parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')
......
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
import torch
TORCH_VERSION = torch.__version__
def is_cuda_available() -> bool:
    """Return whether a CUDA device is available to PyTorch."""
    return bool(torch.cuda.is_available())


# Cached at import time; reflects availability when this module was loaded.
IS_CUDA_AVAILABLE = is_cuda_available()
def is_rocm_pytorch() -> bool:
    """Return True if the installed PyTorch is a ROCm (HIP) build."""
    if TORCH_VERSION == 'parrots':
        # parrots has no ROCm build.
        return False
    try:
        from torch.utils.cpp_extension import ROCM_HOME
    except ImportError:
        return False
    return (torch.version.hip is not None) and (ROCM_HOME is not None)
def _get_cuda_home():
    """Return the CUDA install root of the active backend.

    On ROCm builds of PyTorch this is ROCM_HOME; on parrots it comes from
    the parrots build-extension utilities.
    """
    if TORCH_VERSION == 'parrots':
        from parrots.utils.build_extension import CUDA_HOME
        return CUDA_HOME
    if is_rocm_pytorch():
        from torch.utils.cpp_extension import ROCM_HOME
        return ROCM_HOME
    from torch.utils.cpp_extension import CUDA_HOME
    return CUDA_HOME
def get_build_config():
    """Return the build-configuration string of the active backend."""
    if TORCH_VERSION != 'parrots':
        return torch.__config__.show()
    from parrots.config import get_build_info
    return get_build_info()
def _get_conv():
    """Return the conv base classes (_ConvNd, _ConvTransposeMixin) of the
    active backend."""
    if TORCH_VERSION != 'parrots':
        from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
        return _ConvNd, _ConvTransposeMixin
    from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
    return _ConvNd, _ConvTransposeMixin
def _get_dataloader():
    """Return (DataLoader, PoolDataLoader) of the active backend.

    Outside parrots, ``PoolDataLoader`` is simply an alias of ``DataLoader``.
    """
    from torch.utils.data import DataLoader
    if TORCH_VERSION == 'parrots':
        from torch.utils.data import PoolDataLoader
        return DataLoader, PoolDataLoader
    return DataLoader, DataLoader
def _get_extension():
    """Return (BuildExtension, CppExtension, CUDAExtension) of the active
    backend.

    parrots exposes a single ``Extension`` factory, so the Cpp/CUDA variants
    are derived from it with ``functools.partial``.
    """
    if TORCH_VERSION != 'parrots':
        from torch.utils.cpp_extension import (BuildExtension, CppExtension,
                                               CUDAExtension)
        return BuildExtension, CppExtension, CUDAExtension
    from parrots.utils.build_extension import BuildExtension, Extension
    return (BuildExtension, partial(Extension, cuda=False),
            partial(Extension, cuda=True))
def _get_pool():
    """Return the pooling base classes of the active backend."""
    if TORCH_VERSION != 'parrots':
        from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
                                              _AdaptiveMaxPoolNd, _AvgPoolNd,
                                              _MaxPoolNd)
    else:
        from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
                                             _AdaptiveMaxPoolNd, _AvgPoolNd,
                                             _MaxPoolNd)
    return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
def _get_norm():
    """Return (_BatchNorm, _InstanceNorm, SyncBatchNorm) of the active
    backend."""
    if TORCH_VERSION == 'parrots':
        from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
        return _BatchNorm, _InstanceNorm, torch.nn.SyncBatchNorm2d
    from torch.nn.modules.batchnorm import _BatchNorm
    from torch.nn.modules.instancenorm import _InstanceNorm
    return _BatchNorm, _InstanceNorm, torch.nn.SyncBatchNorm
# Resolve the backend-specific classes once at import time so the rest of
# the codebase can import them directly from this module.
_ConvNd, _ConvTransposeMixin = _get_conv()
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension()
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
class SyncBatchNorm(SyncBatchNorm_):  # type: ignore
    """SyncBatchNorm wrapper that also validates inputs on parrots.

    The parrots base class does not implement ``_check_input_dim``, so the
    dimensionality check is reimplemented here; on torch it simply defers to
    the parent implementation.
    """

    def _check_input_dim(self, input):
        if TORCH_VERSION != 'parrots':
            super()._check_input_dim(input)
        elif input.dim() < 2:
            raise ValueError(
                f'expected at least 2D input (got {input.dim()}D input)')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment