Commit fdeee889 authored by limm

release v1.6.1 of mmcv

parent df465820
@@ -12,16 +12,16 @@ class LogBuffer:
self.output = OrderedDict()
self.ready = False
def clear(self):
def clear(self) -> None:
self.val_history.clear()
self.n_history.clear()
self.clear_output()
def clear_output(self):
def clear_output(self) -> None:
self.output.clear()
self.ready = False
def update(self, vars, count=1):
def update(self, vars: dict, count: int = 1) -> None:
assert isinstance(vars, dict)
for key, var in vars.items():
if key not in self.val_history:
@@ -30,7 +30,7 @@ class LogBuffer:
self.val_history[key].append(var)
self.n_history[key].append(count)
def average(self, n=0):
def average(self, n: int = 0) -> None:
"""Average latest n values or all values."""
assert n >= 0
for key in self.val_history:
......
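For context, a minimal usage sketch of the `LogBuffer` API annotated above, assuming the class is imported from `mmcv.runner`, where it lives; values are illustrative:

```python
from mmcv.runner import LogBuffer

buf = LogBuffer()
# Accumulate two iterations' worth of scalar logs.
buf.update({'loss': 0.52, 'acc': 0.81})
buf.update({'loss': 0.48, 'acc': 0.83})
# average(n) averages the latest n values; n=0 averages everything,
# writing the results into buf.output and setting buf.ready.
buf.average(n=0)
print(buf.output['loss'])  # 0.5
buf.clear_output()  # empties buf.output and resets buf.ready
```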
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
from typing import Dict, List
import torch
@@ -10,7 +11,7 @@ OPTIMIZERS = Registry('optimizer')
OPTIMIZER_BUILDERS = Registry('optimizer builder')
def register_torch_optimizers():
def register_torch_optimizers() -> List:
torch_optimizers = []
for module_name in dir(torch.optim):
if module_name.startswith('__'):
@@ -26,11 +27,11 @@ def register_torch_optimizers():
TORCH_OPTIMIZERS = register_torch_optimizers()
def build_optimizer_constructor(cfg):
def build_optimizer_constructor(cfg: Dict):
return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
def build_optimizer(model, cfg):
def build_optimizer(model, cfg: Dict):
optimizer_cfg = copy.deepcopy(cfg)
constructor_type = optimizer_cfg.pop('constructor',
'DefaultOptimizerConstructor')
......
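A short sketch of the builders above in use, assuming `build_optimizer` is re-exported from `mmcv.runner` as in this release. Because `TORCH_OPTIMIZERS` registers every optimizer class in `torch.optim` by name, any of them can be selected via `cfg['type']`:

```python
import torch.nn as nn
from mmcv.runner import build_optimizer

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
# 'constructor' defaults to 'DefaultOptimizerConstructor' when omitted.
cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4)
optimizer = build_optimizer(model, cfg)
```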
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
from mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
@@ -46,16 +48,17 @@ class DefaultOptimizerConstructor:
would not be added into optimizer. Default: False.
Note:
1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
override the effect of ``bias_lr_mult`` in the bias of offset
layer. So be careful when using both ``bias_lr_mult`` and
``dcn_offset_lr_mult``. If you wish to apply both of them to the
offset layer in deformable convs, set ``dcn_offset_lr_mult``
to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``.
override the effect of ``bias_lr_mult`` in the bias of offset layer.
So be careful when using both ``bias_lr_mult`` and
``dcn_offset_lr_mult``. If you wish to apply both of them to the offset
layer in deformable convs, set ``dcn_offset_lr_mult`` to the original
``dcn_offset_lr_mult`` * ``bias_lr_mult``.
2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
apply it to all the DCN layers in the model. So be careful when
the model contains multiple DCN layers in places other than
backbone.
apply it to all the DCN layers in the model. So be careful when the
model contains multiple DCN layers in places other than backbone.
Args:
model (:obj:`nn.Module`): The model with parameters to be optimized.
@@ -83,7 +86,7 @@ class DefaultOptimizerConstructor:
>>> # assume model have attribute model.backbone and model.cls_head
>>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
>>> paramwise_cfg = dict(custom_keys={
'.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
'backbone': dict(lr_mult=0.1, decay_mult=0.9)})
>>> optim_builder = DefaultOptimizerConstructor(
>>> optimizer_cfg, paramwise_cfg)
>>> optimizer = optim_builder(model)
@@ -92,7 +95,9 @@ class DefaultOptimizerConstructor:
>>> # model.cls_head is (0.01, 0.95).
"""
def __init__(self, optimizer_cfg, paramwise_cfg=None):
def __init__(self,
optimizer_cfg: Dict,
paramwise_cfg: Optional[Dict] = None):
if not isinstance(optimizer_cfg, dict):
raise TypeError('optimizer_cfg should be a dict, '
f'but got {type(optimizer_cfg)}')
@@ -102,7 +107,7 @@ class DefaultOptimizerConstructor:
self.base_wd = optimizer_cfg.get('weight_decay', None)
self._validate_cfg()
def _validate_cfg(self):
def _validate_cfg(self) -> None:
if not isinstance(self.paramwise_cfg, dict):
raise TypeError('paramwise_cfg should be None or a dict, '
f'but got {type(self.paramwise_cfg)}')
@@ -125,7 +130,7 @@ class DefaultOptimizerConstructor:
if self.base_wd is None:
raise ValueError('base_wd should not be None')
def _is_in(self, param_group, param_group_list):
def _is_in(self, param_group: Dict, param_group_list: List) -> bool:
assert is_list_of(param_group_list, dict)
param = set(param_group['params'])
param_set = set()
@@ -134,7 +139,11 @@ class DefaultOptimizerConstructor:
return not param.isdisjoint(param_set)
def add_params(self, params, module, prefix='', is_dcn_module=None):
def add_params(self,
params: List[Dict],
module: nn.Module,
prefix: str = '',
is_dcn_module: Union[int, float, None] = None) -> None:
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
@@ -231,7 +240,7 @@ class DefaultOptimizerConstructor:
prefix=child_prefix,
is_dcn_module=is_dcn_module)
def __call__(self, model):
def __call__(self, model: nn.Module):
if hasattr(model, 'module'):
model = model.module
@@ -242,7 +251,7 @@ class DefaultOptimizerConstructor:
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
# set param-wise lr and weight decay recursively
params = []
params: List[Dict] = []
self.add_params(params, model)
optimizer_cfg['params'] = params
......
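To make the class docstring concrete, a hedged sketch that applies per-parameter multipliers through `paramwise_cfg`; module names and numbers are illustrative, and the expected lr/decay values follow the docstring example:

```python
import torch.nn as nn
from mmcv.runner import DefaultOptimizerConstructor

model = nn.ModuleDict({
    'backbone': nn.Linear(8, 8),
    'cls_head': nn.Linear(8, 2),
})
optim_builder = DefaultOptimizerConstructor(
    optimizer_cfg=dict(type='SGD', lr=0.01, weight_decay=0.95),
    paramwise_cfg=dict(
        custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=0.9)}))
optimizer = optim_builder(model)
# Parameters under 'backbone' get (lr=0.001, weight_decay=0.855);
# 'cls_head' keeps the base (lr=0.01, weight_decay=0.95).
```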
# Copyright (c) OpenMMLab. All rights reserved.
from enum import Enum
from typing import Union
class Priority(Enum):
@@ -39,7 +40,7 @@ class Priority(Enum):
LOWEST = 100
def get_priority(priority):
def get_priority(priority: Union[int, str, Priority]) -> int:
"""Get priority value.
Args:
......
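A quick sketch of `get_priority` accepting each of the three annotated input types, assuming both symbols are exported from `mmcv.runner`:

```python
from mmcv.runner import Priority, get_priority

assert get_priority(100) == 100              # plain int in [0, 100]
assert get_priority('LOWEST') == 100         # enum name, matched case-insensitively
assert get_priority(Priority.LOWEST) == 100  # Priority member
```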
@@ -6,6 +6,8 @@ import time
import warnings
from getpass import getuser
from socket import gethostname
from types import ModuleType
from typing import Optional
import numpy as np
import torch
@@ -13,7 +15,7 @@ import torch
import mmcv
def get_host_info():
def get_host_info() -> str:
"""Get hostname and username.
Return an empty string if an exception is raised, e.g. ``getpass.getuser()`` will
@@ -28,11 +30,13 @@ def get_host_info():
return host
def get_time_str():
def get_time_str() -> str:
return time.strftime('%Y%m%d_%H%M%S', time.localtime())
def obj_from_dict(info, parent=None, default_args=None):
def obj_from_dict(info: dict,
parent: Optional[ModuleType] = None,
default_args: Optional[dict] = None):
"""Initialize an object from dict.
The dict must contain the key "type", which indicates the object type, it
@@ -67,7 +71,9 @@ def obj_from_dict(info, parent=None, default_args=None):
return obj_type(**args)
def set_random_seed(seed, deterministic=False, use_rank_shift=False):
def set_random_seed(seed: int,
deterministic: bool = False,
use_rank_shift: bool = False) -> None:
"""Set random seed.
Args:
......
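A sketch of the two annotated helpers in use, with imports assumed from `mmcv.runner`; `obj_from_dict` instantiates `parent.<type>` with the remaining dict keys merged with `default_args`:

```python
import torch
import torch.nn as nn
from mmcv.runner import obj_from_dict, set_random_seed

# Seeds random/numpy/torch and, with deterministic=True, makes cuDNN
# deterministic (disabling benchmark mode).
set_random_seed(0, deterministic=True)

model = nn.Linear(4, 2)
# Equivalent to torch.optim.Adam(params=model.parameters(), lr=1e-3).
optimizer = obj_from_dict(
    dict(type='Adam', lr=1e-3),
    parent=torch.optim,
    default_args=dict(params=model.parameters()))
```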
@@ -22,9 +22,9 @@ if is_tensorrt_available():
# load tensorrt plugin lib
load_tensorrt_plugin()
__all__.append([
__all__.extend([
'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper',
'TRTWrapper'
])
__all__.append(['is_tensorrt_plugin_loaded', 'preprocess_onnx'])
__all__.extend(['is_tensorrt_plugin_loaded', 'preprocess_onnx'])
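The switch from `append` to `extend` fixes a real bug: `list.append` adds its argument as a single element, so `__all__` previously contained nested lists rather than the symbol names. A minimal illustration:

```python
names = ['a']
names.append(['b', 'c'])  # ['a', ['b', 'c']] -- a nested list, wrong for __all__
names = ['a']
names.extend(['b', 'c'])  # ['a', 'b', 'c']   -- what was intended
```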
@@ -2,10 +2,23 @@
import ctypes
import glob
import os
import warnings
def get_tensorrt_op_path():
def get_tensorrt_op_path() -> str:
"""Get TensorRT plugins library path."""
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This function will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
wildcard = os.path.join(
os.path.abspath(os.path.dirname(os.path.dirname(__file__))),
'_ext_trt.*.so')
@@ -18,18 +31,44 @@ def get_tensorrt_op_path():
plugin_is_loaded = False
def is_tensorrt_plugin_loaded():
def is_tensorrt_plugin_loaded() -> bool:
"""Check if TensorRT plugins library is loaded or not.
Returns:
bool: plugin_is_loaded flag
"""
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This function will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
global plugin_is_loaded
return plugin_is_loaded
def load_tensorrt_plugin():
def load_tensorrt_plugin() -> None:
"""load TensorRT plugins library."""
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This function will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
global plugin_is_loaded
lib_path = get_tensorrt_op_path()
if (not plugin_is_loaded) and os.path.exists(lib_path):
......
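The ANSI-styled deprecation banner introduced above is duplicated verbatim in every public function of the TensorRT modules; a hypothetical helper, not part of this commit, could centralize it:

```python
import warnings


def _warn_trt_deprecated() -> None:
    """Hypothetical helper emitting the shared MMDeploy migration banner."""
    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'
    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This function will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)
```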
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import onnx
def preprocess_onnx(onnx_model):
def preprocess_onnx(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Modify onnx model to match with TensorRT plugins in mmcv.
There are some conflicts between onnx node definitions and TensorRT limits.
@@ -18,6 +21,19 @@ def preprocess_onnx(onnx_model):
Returns:
onnx.ModelProto: Modified onnx model.
"""
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This function will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
graph = onnx_model.graph
nodes = graph.node
initializers = graph.initializer
......
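For reference, a hedged sketch of the intended call pattern for `preprocess_onnx`; file names are illustrative, and the function is exported from `mmcv.tensorrt`:

```python
import onnx

from mmcv.tensorrt import preprocess_onnx

model = onnx.load('model.onnx')  # hypothetical input path
model = preprocess_onnx(model)  # rewrite nodes to match mmcv's TRT plugins
onnx.save(model, 'model_trt_ready.onnx')  # hypothetical output path
```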
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Union
import onnx
import tensorrt as trt
@@ -8,12 +9,12 @@ import torch
from .preprocess import preprocess_onnx
def onnx2trt(onnx_model,
opt_shape_dict,
log_level=trt.Logger.ERROR,
fp16_mode=False,
max_workspace_size=0,
device_id=0):
def onnx2trt(onnx_model: Union[str, onnx.ModelProto],
opt_shape_dict: dict,
log_level: trt.ILogger.Severity = trt.Logger.ERROR,
fp16_mode: bool = False,
max_workspace_size: int = 0,
device_id: int = 0) -> trt.ICudaEngine:
"""Convert onnx model to tensorrt engine.
Arguments:
@@ -40,7 +41,20 @@ def onnx2trt(onnx_model,
>>> device_id=0)
>>> })
"""
device = torch.device('cuda:{}'.format(device_id))
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This function will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
device = torch.device(f'cuda:{device_id}')
# create builder and network
logger = trt.Logger(log_level)
builder = trt.Builder(logger)
@@ -87,18 +101,31 @@ def onnx2trt(onnx_model,
return engine
def save_trt_engine(engine, path):
def save_trt_engine(engine: trt.ICudaEngine, path: str) -> None:
"""Serialize TensorRT engine to disk.
Arguments:
engine (tensorrt.ICudaEngine): TensorRT engine to serialize
path (str): disk path to write the engine
"""
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This function will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
with open(path, mode='wb') as f:
f.write(bytearray(engine.serialize()))
def load_trt_engine(path):
def load_trt_engine(path: str) -> trt.ICudaEngine:
"""Deserialize TensorRT engine from disk.
Arguments:
@@ -107,6 +134,19 @@ def load_trt_engine(path):
Returns:
tensorrt.ICudaEngine: the TensorRT engine loaded from disk
"""
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This function will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
with open(path, mode='rb') as f:
engine_bytes = f.read()
@@ -114,7 +154,7 @@ def load_trt_engine(path):
return engine
def torch_dtype_from_trt(dtype):
def torch_dtype_from_trt(dtype: trt.DataType) -> torch.dtype:
"""Convert TensorRT dtype to PyTorch dtype."""
if dtype == trt.bool:
return torch.bool
@@ -130,7 +170,8 @@ def torch_dtype_from_trt(dtype):
raise TypeError('%s is not supported by torch' % dtype)
def torch_device_from_trt(device):
def torch_device_from_trt(device: trt.TensorLocation) -> torch.device:
"""Convert TensorRT device to PyTorch device."""
if device == trt.TensorLocation.DEVICE:
return torch.device('cuda')
@@ -154,7 +195,21 @@ class TRTWrapper(torch.nn.Module):
"""
def __init__(self, engine, input_names=None, output_names=None):
super(TRTWrapper, self).__init__()
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This tool will be deprecated in future. '
msg += blue_text + \
'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
super().__init__()
self.engine = engine
if isinstance(self.engine, str):
self.engine = load_trt_engine(engine)
@@ -231,5 +286,6 @@ class TRTWraper(TRTWrapper):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn('TRTWraper will be deprecated in'
' future. Please use TRTWrapper instead')
warnings.warn(
'TRTWraper will be deprecated in'
' future. Please use TRTWrapper instead', DeprecationWarning)
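Putting the annotated TensorRT API together, a hedged end-to-end sketch; shapes, file names, and tensor names are illustrative, and a CUDA device plus mmcv's TensorRT extension are required:

```python
import torch
from mmcv.tensorrt import TRTWrapper, onnx2trt, save_trt_engine

# opt_shape_dict maps each input name to [min_shape, optimal_shape, max_shape].
opt_shape_dict = {'input': [[1, 3, 224, 224]] * 3}
engine = onnx2trt('model.onnx', opt_shape_dict, fp16_mode=True, device_id=0)
save_trt_engine(engine, 'model.trt')

trt_model = TRTWrapper('model.trt', input_names=['input'], output_names=['output'])
with torch.no_grad():
    outputs = trt_model({'input': torch.randn(1, 3, 224, 224).cuda()})
print(outputs['output'].shape)
```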
@@ -36,17 +36,26 @@ except ImportError:
'is_method_overridden', 'has_method'
]
else:
from .device_type import (IS_IPU_AVAILABLE, IS_MLU_AVAILABLE,
IS_MPS_AVAILABLE)
from .env import collect_env
from .hub import load_url
from .logging import get_logger, print_log
from .parrots_jit import jit, skip_no_elena
from .parrots_wrapper import (
TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader,
PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
_AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
_MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home)
# yapf: disable
from .parrots_wrapper import (IS_CUDA_AVAILABLE, TORCH_VERSION,
BuildExtension, CppExtension, CUDAExtension,
DataLoader, PoolDataLoader, SyncBatchNorm,
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
_AvgPoolNd, _BatchNorm, _ConvNd,
_ConvTransposeMixin, _get_cuda_home,
_InstanceNorm, _MaxPoolNd, get_build_config,
is_rocm_pytorch)
# yapf: enable
from .registry import Registry, build_from_cfg
from .seed import worker_init_fn
from .torch_ops import torch_meshgrid
from .trace import is_jit_tracing
from .hub import load_url
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
@@ -66,5 +75,7 @@ else:
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
'_get_cuda_home', 'load_url', 'has_method'
'_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE',
'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE',
'IS_MPS_AVAILABLE', 'torch_meshgrid'
]
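A short sketch exercising the newly exported symbols, assuming this mmcv build; `torch_meshgrid` (from `.torch_ops`) wraps `torch.meshgrid` and passes `indexing='ij'` on PyTorch >= 1.10 so no version-dependent warning is raised:

```python
import torch
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_IPU_AVAILABLE, IS_MLU_AVAILABLE,
                        IS_MPS_AVAILABLE, torch_meshgrid)

# Device-availability flags now exposed at the package level.
print(IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_IPU_AVAILABLE, IS_MPS_AVAILABLE)

xs, ys = torch_meshgrid(torch.arange(2), torch.arange(3))
print(xs.shape, ys.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```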