Unverified commit 27d1d7fe authored by tripleMu, committed by GitHub

Add type hints for mmcv/cnn/utils (#1996)



* Add typehint in mmcv/cnn/utils/*

* Update mmcv/cnn/utils/flops_counter.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/utils/weight_init.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Fix bugs

* Fix

* Fix2

* Update mmcv/cnn/utils/weight_init.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/utils/flops_counter.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/utils/weight_init.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/utils/weight_init.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Fix

* minor fix

* line is too long.

* Update mmcv/cnn/utils/weight_init.py
Co-authored-by: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com>

* Fix

* Update mmcv/cnn/utils/weight_init.py

* fix default value of float type hint

* fix default value of float type hint

* fix default value of float type hint

* fix default value of float type hint

* fix default value of float type hint

* fix default value of float type hint

* fix default value of float type hint

* Fix

* minor refinement

* replace list with tuple for input
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com>
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: zhouzaida <zhouzaida@163.com>
parent 0230fc3b
......@@ -26,6 +26,7 @@
import sys
import warnings
from functools import partial
from typing import Any, Callable, Dict, Optional, TextIO, Tuple
import numpy as np
import torch
......@@ -34,13 +35,13 @@ import torch.nn as nn
import mmcv
def get_model_complexity_info(model,
input_shape,
print_per_layer_stat=True,
as_strings=True,
input_constructor=None,
flush=False,
ost=sys.stdout):
def get_model_complexity_info(model: nn.Module,
input_shape: tuple,
print_per_layer_stat: bool = True,
as_strings: bool = True,
input_constructor: Optional[Callable] = None,
flush: bool = False,
ost: TextIO = sys.stdout) -> tuple:
"""Get complexity information of a model.
This method can calculate FLOPs and parameter counts of a model with
......@@ -116,7 +117,9 @@ def get_model_complexity_info(model,
return flops_count, params_count
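For orientation, a minimal usage sketch of the function annotated above; the torchvision resnet18 model and the printed values are illustrative assumptions, not part of this change.

import torch.nn as nn
from torchvision.models import resnet18

from mmcv.cnn import get_model_complexity_info

model: nn.Module = resnet18()
flops, params = get_model_complexity_info(
    model, input_shape=(3, 224, 224), as_strings=True,
    print_per_layer_stat=False)
print(flops, params)  # roughly '1.82 GFLOPs' '11.69 M'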
def flops_to_string(flops, units='GFLOPs', precision=2):
def flops_to_string(flops: float,
units: Optional[str] = 'GFLOPs',
precision: int = 2) -> str:
"""Convert FLOPs number into a string.
Note that here we count one multiply-add as one FLOP.
......@@ -159,7 +162,9 @@ def flops_to_string(flops, units='GFLOPs', precision=2):
return str(flops) + ' FLOPs'
def params_to_string(num_params, units=None, precision=2):
def params_to_string(num_params: float,
units: Optional[str] = None,
precision: int = 2) -> str:
"""Convert parameter number into a string.
Args:
......@@ -196,13 +201,13 @@ def params_to_string(num_params, units=None, precision=2):
return str(num_params)
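A few hedged spot checks of the string helpers annotated above, importing them from the module this diff touches; the exact rounding follows the implementation shown:

from mmcv.cnn.utils.flops_counter import flops_to_string, params_to_string

print(flops_to_string(1e9))                  # '1.0 GFLOPs' with the default units
print(flops_to_string(1e9, units='MFLOPs'))  # '1000.0 MFLOPs'
print(params_to_string(1_000_000))           # '1.0 M' (units picked automatically)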
def print_model_with_flops(model,
total_flops,
total_params,
units='GFLOPs',
precision=3,
ost=sys.stdout,
flush=False):
def print_model_with_flops(model: nn.Module,
total_flops: float,
total_params: float,
units: Optional[str] = 'GFLOPs',
precision: int = 3,
ost: TextIO = sys.stdout,
flush: bool = False) -> None:
"""Print a model with FLOPs for each layer.
Args:
......@@ -305,7 +310,7 @@ def print_model_with_flops(model,
model.apply(del_extra_repr)
def get_model_parameters_number(model):
def get_model_parameters_number(model: nn.Module) -> float:
"""Calculate parameter number of a model.
Args:
......@@ -318,16 +323,16 @@ def get_model_parameters_number(model):
return num_params
def add_flops_counting_methods(net_main_module):
def add_flops_counting_methods(net_main_module: nn.Module) -> nn.Module:
# adding additional methods to the existing module object,
# this is done this way so that each function has access to self object
net_main_module.start_flops_count = start_flops_count.__get__(
net_main_module.start_flops_count = start_flops_count.__get__( # type: ignore # noqa E501
net_main_module)
net_main_module.stop_flops_count = stop_flops_count.__get__(
net_main_module.stop_flops_count = stop_flops_count.__get__( # type: ignore # noqa E501
net_main_module)
net_main_module.reset_flops_count = reset_flops_count.__get__(
net_main_module.reset_flops_count = reset_flops_count.__get__( # type: ignore # noqa E501
net_main_module)
net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( # noqa: E501
net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( # type: ignore # noqa E501
net_main_module)
net_main_module.reset_flops_count()
......@@ -335,7 +340,7 @@ def add_flops_counting_methods(net_main_module):
return net_main_module
def compute_average_flops_cost(self):
def compute_average_flops_cost(self) -> Tuple[float, float]:
"""Compute average FLOPs cost.
A method to compute average FLOPs cost, which will be available after
......@@ -353,7 +358,7 @@ def compute_average_flops_cost(self):
return flops_sum / batches_count, params_sum
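The methods bound above are meant to be used in a start/forward/compute/stop cycle; a rough sketch, assuming a plain Conv2d as the counted module and a random input:

import torch
import torch.nn as nn
from mmcv.cnn.utils.flops_counter import add_flops_counting_methods

net = add_flops_counting_methods(nn.Conv2d(3, 8, 3, padding=1))
net.eval()
net.start_flops_count()
with torch.no_grad():
    net(torch.randn(1, 3, 32, 32))
flops, params = net.compute_average_flops_cost()  # mean FLOPs per batch, parameter count
net.stop_flops_count()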
def start_flops_count(self):
def start_flops_count(self) -> None:
"""Activate the computation of mean flops consumption per image.
A method to activate the computation of mean flops consumption per image.
......@@ -362,7 +367,7 @@ def start_flops_count(self):
"""
add_batch_counter_hook_function(self)
def add_flops_counter_hook_function(module):
def add_flops_counter_hook_function(module: nn.Module) -> None:
if is_supported_instance(module):
if hasattr(module, '__flops_handle__'):
return
......@@ -376,7 +381,7 @@ def start_flops_count(self):
self.apply(partial(add_flops_counter_hook_function))
def stop_flops_count(self):
def stop_flops_count(self) -> None:
"""Stop computing the mean flops consumption per image.
A method to stop computing the mean flops consumption per image, which will
......@@ -387,7 +392,7 @@ def stop_flops_count(self):
self.apply(remove_flops_counter_hook_function)
def reset_flops_count(self):
def reset_flops_count(self) -> None:
"""Reset statistics computed so far.
A method to reset computed statistics, which will be available after
......@@ -398,11 +403,13 @@ def reset_flops_count(self):
# ---- Internal functions
def empty_flops_counter_hook(module, input, output):
def empty_flops_counter_hook(module: nn.Module, input: tuple,
output: Any) -> None:
module.__flops__ += 0
def upsample_flops_counter_hook(module, input, output):
def upsample_flops_counter_hook(module: nn.Module, input: tuple,
output: torch.Tensor) -> None:
output_size = output[0]
batch_size = output_size.shape[0]
output_elements_count = batch_size
......@@ -411,39 +418,38 @@ def upsample_flops_counter_hook(module, input, output):
module.__flops__ += int(output_elements_count)
def relu_flops_counter_hook(module, input, output):
def relu_flops_counter_hook(module: nn.Module, input: tuple,
output: torch.Tensor) -> None:
active_elements_count = output.numel()
module.__flops__ += int(active_elements_count)
def linear_flops_counter_hook(module, input, output):
input = input[0]
def linear_flops_counter_hook(module: nn.Module, input: tuple,
output: torch.Tensor) -> None:
output_last_dim = output.shape[
-1] # pytorch checks dimensions, so here we don't care much
module.__flops__ += int(np.prod(input.shape) * output_last_dim)
module.__flops__ += int(np.prod(input[0].shape) * output_last_dim)
def pool_flops_counter_hook(module, input, output):
input = input[0]
module.__flops__ += int(np.prod(input.shape))
def pool_flops_counter_hook(module: nn.Module, input: tuple,
output: torch.Tensor) -> None:
module.__flops__ += int(np.prod(input[0].shape))
def norm_flops_counter_hook(module, input, output):
input = input[0]
batch_flops = np.prod(input.shape)
def norm_flops_counter_hook(module: nn.Module, input: tuple,
output: torch.Tensor) -> None:
batch_flops = np.prod(input[0].shape)
if (getattr(module, 'affine', False)
or getattr(module, 'elementwise_affine', False)):
batch_flops *= 2
module.__flops__ += int(batch_flops)
def deconv_flops_counter_hook(conv_module, input, output):
def deconv_flops_counter_hook(conv_module: nn.Module, input: tuple,
output: torch.Tensor) -> None:
# Can have multiple inputs, getting the first one
input = input[0]
batch_size = input.shape[0]
input_height, input_width = input.shape[2:]
batch_size = input[0].shape[0]
input_height, input_width = input[0].shape[2:]
kernel_height, kernel_width = conv_module.kernel_size
in_channels = conv_module.in_channels
......@@ -465,11 +471,10 @@ def deconv_flops_counter_hook(conv_module, input, output):
conv_module.__flops__ += int(overall_flops)
def conv_flops_counter_hook(conv_module, input, output):
def conv_flops_counter_hook(conv_module: nn.Module, input: tuple,
output: torch.Tensor) -> None:
# Can have multiple inputs, getting the first one
input = input[0]
batch_size = input.shape[0]
batch_size = input[0].shape[0]
output_dims = list(output.shape[2:])
kernel_dims = list(conv_module.kernel_size)
......@@ -496,24 +501,23 @@ def conv_flops_counter_hook(conv_module, input, output):
conv_module.__flops__ += int(overall_flops)
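As an arithmetic sanity check of the hook above (plain numbers mirroring its bookkeeping, not mmcv code): a 3x3 convolution from 3 to 8 channels on a 1x3x32x32 input with padding=1 gives a 1x8x32x32 output, so:

kernel_flops = 3 * 3 * (3 // 1) * 8     # kernel dims * (in_channels / groups) * out_channels
active_elements = 1 * 32 * 32           # batch size * spatial positions of the output
bias_flops = 8 * active_elements        # counted only when the conv has a bias
overall_flops = kernel_flops * active_elements + bias_flops
print(overall_flops)  # 229376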
def batch_counter_hook(module, input, output):
def batch_counter_hook(module: nn.Module, input: tuple, output: Any) -> None:
batch_size = 1
if len(input) > 0:
# Can have multiple inputs, getting the first one
input = input[0]
batch_size = len(input)
batch_size = len(input[0])
else:
warnings.warn('No positional inputs found for a module, '
'assuming batch size is 1.')
module.__batch_counter__ += batch_size
def add_batch_counter_variables_or_reset(module):
def add_batch_counter_variables_or_reset(module: nn.Module) -> None:
module.__batch_counter__ = 0
def add_batch_counter_hook_function(module):
def add_batch_counter_hook_function(module: nn.Module) -> None:
if hasattr(module, '__batch_counter_handle__'):
return
......@@ -521,13 +525,13 @@ def add_batch_counter_hook_function(module):
module.__batch_counter_handle__ = handle
def remove_batch_counter_hook_function(module):
def remove_batch_counter_hook_function(module: nn.Module) -> None:
if hasattr(module, '__batch_counter_handle__'):
module.__batch_counter_handle__.remove()
del module.__batch_counter_handle__
def add_flops_counter_variable_or_reset(module):
def add_flops_counter_variable_or_reset(module: nn.Module) -> None:
if is_supported_instance(module):
if hasattr(module, '__flops__') or hasattr(module, '__params__'):
warnings.warn('variables __flops__ or __params__ are already '
......@@ -537,20 +541,20 @@ def add_flops_counter_variable_or_reset(module):
module.__params__ = get_model_parameters_number(module)
def is_supported_instance(module):
def is_supported_instance(module: nn.Module) -> bool:
if type(module) in get_modules_mapping():
return True
return False
def remove_flops_counter_hook_function(module):
def remove_flops_counter_hook_function(module: nn.Module) -> None:
if is_supported_instance(module):
if hasattr(module, '__flops_handle__'):
module.__flops_handle__.remove()
del module.__flops_handle__
def get_modules_mapping():
def get_modules_mapping() -> Dict:
return {
# convolutions
nn.Conv1d: conv_flops_counter_hook,
......
......@@ -3,7 +3,7 @@ import torch
import torch.nn as nn
def _fuse_conv_bn(conv, bn):
def _fuse_conv_bn(conv: nn.Module, bn: nn.Module) -> nn.Module:
"""Fuse conv and bn into one module.
Args:
......@@ -24,7 +24,7 @@ def _fuse_conv_bn(conv, bn):
return conv
def fuse_conv_bn(module):
def fuse_conv_bn(module: nn.Module) -> nn.Module:
"""Recursively fuse conv and bn in a module.
During inference, the functionality of batch norm layers is turned off
......
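A hedged usage sketch for the fusion helper, assuming the mmcv 1.x export from mmcv.cnn and an illustrative Conv-BN-ReLU stack:

import torch
import torch.nn as nn
from mmcv.cnn import fuse_conv_bn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.eval()                    # fusion targets inference-mode BN statistics
x = torch.randn(1, 3, 16, 16)
ref = model(x)
fused = fuse_conv_bn(model)     # fuses in place and returns the module
assert torch.allclose(ref, fused(x), atol=1e-5)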
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import mmcv
class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
"""A general BatchNorm layer without input dimension check.
Reproduced from @kapily's work:
......@@ -15,11 +16,11 @@ class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
SyncBatchNorm.
"""
def _check_input_dim(self, input):
def _check_input_dim(self, input: torch.Tensor):
return
def revert_sync_batchnorm(module):
def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
"""Helper function to convert all `SyncBatchNorm` (SyncBN) and
`mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
`BatchNormXd` layers.
......
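A short sketch of reverting SyncBN for single-process use, assuming revert_sync_batchnorm is imported from mmcv.cnn as in mmcv 1.x:

import torch.nn as nn
from mmcv.cnn import revert_sync_batchnorm

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
model = revert_sync_batchnorm(model)
print(type(model[1]).__name__)  # _BatchNormXd, usable without a process group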
......@@ -2,6 +2,7 @@
import copy
import math
import warnings
from typing import Dict, List, Optional, Union
import numpy as np
import torch
......@@ -13,7 +14,7 @@ from mmcv.utils import Registry, build_from_cfg, get_logger, print_log
INITIALIZERS = Registry('initializer')
def update_init_info(module, init_info):
def update_init_info(module: nn.Module, init_info: str) -> None:
Update the `_params_init_info` in the module if the values of the parameters
are changed.
......@@ -45,14 +46,17 @@ def update_init_info(module, init_info):
module._params_init_info[param]['tmp_mean_value'] = mean_value
def constant_init(module, val, bias=0):
def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
if hasattr(module, 'weight') and module.weight is not None:
nn.init.constant_(module.weight, val)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def xavier_init(module, gain=1, bias=0, distribution='normal'):
def xavier_init(module: nn.Module,
gain: float = 1,
bias: float = 0,
distribution: str = 'normal') -> None:
assert distribution in ['uniform', 'normal']
if hasattr(module, 'weight') and module.weight is not None:
if distribution == 'uniform':
......@@ -63,7 +67,10 @@ def xavier_init(module, gain=1, bias=0, distribution='normal'):
nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0):
def normal_init(module: nn.Module,
mean: float = 0,
std: float = 1,
bias: float = 0) -> None:
if hasattr(module, 'weight') and module.weight is not None:
nn.init.normal_(module.weight, mean, std)
if hasattr(module, 'bias') and module.bias is not None:
......@@ -82,19 +89,22 @@ def trunc_normal_init(module: nn.Module,
nn.init.constant_(module.bias, bias) # type: ignore
def uniform_init(module, a=0, b=1, bias=0):
def uniform_init(module: nn.Module,
a: float = 0,
b: float = 1,
bias: float = 0) -> None:
if hasattr(module, 'weight') and module.weight is not None:
nn.init.uniform_(module.weight, a, b)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def kaiming_init(module,
a=0,
mode='fan_out',
nonlinearity='relu',
bias=0,
distribution='normal'):
def kaiming_init(module: nn.Module,
a: float = 0,
mode: str = 'fan_out',
nonlinearity: str = 'relu',
bias: float = 0,
distribution: str = 'normal') -> None:
assert distribution in ['uniform', 'normal']
if hasattr(module, 'weight') and module.weight is not None:
if distribution == 'uniform':
......@@ -107,7 +117,7 @@ def kaiming_init(module,
nn.init.constant_(module.bias, bias)
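Direct calls to the functional initializers typed above; a small sketch with illustrative layers, assuming the usual mmcv.cnn exports:

import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init

conv = nn.Conv2d(3, 8, 3)
kaiming_init(conv, mode='fan_out', nonlinearity='relu', bias=0.)
bn = nn.BatchNorm2d(8)
constant_init(bn, val=1., bias=0.)  # common BN init: unit scale, zero shift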
def caffe2_xavier_init(module, bias=0):
def caffe2_xavier_init(module: nn.Module, bias: float = 0) -> None:
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
kaiming_init(
......@@ -119,19 +129,23 @@ def caffe2_xavier_init(module, bias=0):
distribution='uniform')
def bias_init_with_prob(prior_prob):
def bias_init_with_prob(prior_prob: float) -> float:
"""initialize conv/fc bias value according to a given probability value."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init
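A quick numeric check of the formula above: with a prior probability of 0.01 (a common focal-loss prior) the returned bias is -log((1 - 0.01) / 0.01):

import numpy as np
print(float(-np.log((1 - 0.01) / 0.01)))  # about -4.595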
def _get_bases_name(m):
def _get_bases_name(m: nn.Module) -> List[str]:
return [b.__name__ for b in m.__class__.__bases__]
class BaseInit:
def __init__(self, *, bias=0, bias_prob=None, layer=None):
def __init__(self,
*,
bias: float = 0,
bias_prob: Optional[float] = None,
layer: Union[str, List, None] = None):
self.wholemodule = False
if not isinstance(bias, (int, float)):
raise TypeError(f'bias must be a number, but got a {type(bias)}')
......@@ -154,7 +168,7 @@ class BaseInit:
self.bias = bias
self.layer = [layer] if isinstance(layer, str) else layer
def _get_init_info(self):
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}, bias={self.bias}'
return info
......@@ -172,11 +186,11 @@ class ConstantInit(BaseInit):
Defaults to None.
"""
def __init__(self, val, **kwargs):
def __init__(self, val: Union[int, float], **kwargs):
super().__init__(**kwargs)
self.val = val
def __call__(self, module):
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
......@@ -191,7 +205,7 @@ class ConstantInit(BaseInit):
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
return info
......@@ -214,12 +228,15 @@ class XavierInit(BaseInit):
Defaults to None.
"""
def __init__(self, gain=1, distribution='normal', **kwargs):
def __init__(self,
gain: float = 1,
distribution: str = 'normal',
**kwargs):
super().__init__(**kwargs)
self.gain = gain
self.distribution = distribution
def __call__(self, module):
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
......@@ -234,7 +251,7 @@ class XavierInit(BaseInit):
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}: gain={self.gain}, ' \
f'distribution={self.distribution}, bias={self.bias}'
return info
......@@ -257,12 +274,12 @@ class NormalInit(BaseInit):
"""
def __init__(self, mean=0, std=1, **kwargs):
def __init__(self, mean: float = 0, std: float = 1, **kwargs):
super().__init__(**kwargs)
self.mean = mean
self.std = std
def __call__(self, module):
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
......@@ -277,7 +294,7 @@ class NormalInit(BaseInit):
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}: mean={self.mean},' \
f' std={self.std}, bias={self.bias}'
return info
......@@ -355,12 +372,12 @@ class UniformInit(BaseInit):
Defaults to None.
"""
def __init__(self, a=0, b=1, **kwargs):
def __init__(self, a: float = 0., b: float = 1., **kwargs):
super().__init__(**kwargs)
self.a = a
self.b = b
def __call__(self, module):
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
......@@ -375,7 +392,7 @@ class UniformInit(BaseInit):
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}: a={self.a},' \
f' b={self.b}, bias={self.bias}'
return info
......@@ -409,10 +426,10 @@ class KaimingInit(BaseInit):
"""
def __init__(self,
a=0,
mode='fan_out',
nonlinearity='relu',
distribution='normal',
a: float = 0,
mode: str = 'fan_out',
nonlinearity: str = 'relu',
distribution: str = 'normal',
**kwargs):
super().__init__(**kwargs)
self.a = a
......@@ -420,7 +437,7 @@ class KaimingInit(BaseInit):
self.nonlinearity = nonlinearity
self.distribution = distribution
def __call__(self, module):
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
......@@ -437,7 +454,7 @@ class KaimingInit(BaseInit):
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
f'nonlinearity={self.nonlinearity}, ' \
f'distribution ={self.distribution}, bias={self.bias}'
......@@ -456,7 +473,7 @@ class Caffe2XavierInit(KaimingInit):
distribution='uniform',
**kwargs)
def __call__(self, module):
def __call__(self, module: nn.Module) -> None:
super().__call__(module)
......@@ -475,12 +492,15 @@ class PretrainedInit:
map_location (str): map tensors into proper locations.
"""
def __init__(self, checkpoint, prefix=None, map_location=None):
def __init__(self,
checkpoint: str,
prefix: Optional[str] = None,
map_location: Optional[str] = None):
self.checkpoint = checkpoint
self.prefix = prefix
self.map_location = map_location
def __call__(self, module):
def __call__(self, module: nn.Module) -> None:
from mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
load_state_dict)
logger = get_logger('mmcv')
......@@ -503,12 +523,14 @@ class PretrainedInit:
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}: load from {self.checkpoint}'
return info
def _initialize(module, cfg, wholemodule=False):
def _initialize(module: nn.Module,
cfg: Dict,
wholemodule: bool = False) -> None:
func = build_from_cfg(cfg, INITIALIZERS)
# wholemodule flag is for override mode, there is no layer key in override
# and initializer will give init values for the whole module with the name
......@@ -517,7 +539,8 @@ def _initialize(module, cfg, wholemodule=False):
func(module)
def _initialize_override(module, override, cfg):
def _initialize_override(module: nn.Module, override: Union[Dict, List],
cfg: Dict) -> None:
if not isinstance(override, (dict, list)):
raise TypeError(f'override must be a dict or a list of dict, \
but got {type(override)}')
......@@ -547,7 +570,7 @@ def _initialize_override(module, override, cfg):
f'but init_cfg is {cp_override}.')
def initialize(module, init_cfg):
def initialize(module: nn.Module, init_cfg: Union[Dict, List[dict]]) -> None:
r"""Initialize a module.
Args:
......
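A hedged, config-driven sketch built on the registry-based initialize() annotated above; the layers and values are illustrative, not part of this change:

import torch.nn as nn
from mmcv.cnn import initialize

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 4))
init_cfg = [
    dict(type='Kaiming', layer='Conv2d', nonlinearity='relu'),
    dict(type='Constant', layer='Linear', val=0., bias=1.),
]
initialize(model, init_cfg)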
......@@ -40,7 +40,7 @@ def _get_mmcv_home():
def load_state_dict(module: torch.nn.Module,
state_dict: OrderedDict,
state_dict: Union[dict, OrderedDict],
strict: bool = False,
logger: Optional[logging.Logger] = None) -> None:
"""Load state_dict to a module.
......
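The widened annotation above reflects that a plain dict (not only an OrderedDict) is accepted; a minimal sketch, assuming load_state_dict is importable from mmcv.runner:

import torch.nn as nn
from mmcv.runner import load_state_dict

model = nn.Linear(4, 2)
plain_dict = {k: v.clone() for k, v in model.state_dict().items()}  # a plain dict
load_state_dict(model, plain_dict, strict=True)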