Unverified Commit 27d1d7fe authored by tripleMu, committed by GitHub

Add type hints for mmcv/cnn/utils (#1996)



* Add typehint in mmcv/cnn/utils/*

* Update mmcv/cnn/utils/flops_counter.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/utils/weight_init.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Fix bugs

* Fix

* Fix2

* Update mmcv/cnn/utils/weight_init.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/utils/flops_counter.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/utils/weight_init.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/utils/weight_init.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Fix

* minor fix

* line is too long.

* Update mmcv/cnn/utils/weight_init.py
Co-authored-by: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com>

* Fix

* Update mmcv/cnn/utils/weight_init.py

* fix default value of float type hint

* fix default value of float type hint

* fix default value of float type hint

* fix default value of float type hint

* fix default value of float type hint

* fix default value of float type hint

* fix default value of float type hint

* Fix

* minor refinement

* replace list with tuple for input
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com>
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: zhouzaida <zhouzaida@163.com>
parent 0230fc3b
@@ -26,6 +26,7 @@
 import sys
 import warnings
 from functools import partial
+from typing import Any, Callable, Dict, Optional, TextIO, Tuple

 import numpy as np
 import torch
@@ -34,13 +35,13 @@ import torch.nn as nn
 import mmcv


-def get_model_complexity_info(model,
-                              input_shape,
-                              print_per_layer_stat=True,
-                              as_strings=True,
-                              input_constructor=None,
-                              flush=False,
-                              ost=sys.stdout):
+def get_model_complexity_info(model: nn.Module,
+                              input_shape: tuple,
+                              print_per_layer_stat: bool = True,
+                              as_strings: bool = True,
+                              input_constructor: Optional[Callable] = None,
+                              flush: bool = False,
+                              ost: TextIO = sys.stdout) -> tuple:
     """Get complexity information of a model.

     This method can calculate FLOPs and parameter counts of a model with
@@ -116,7 +117,9 @@ def get_model_complexity_info(model,
     return flops_count, params_count


-def flops_to_string(flops, units='GFLOPs', precision=2):
+def flops_to_string(flops: float,
+                    units: Optional[str] = 'GFLOPs',
+                    precision: int = 2) -> str:
     """Convert FLOPs number into a string.

     Note that Here we take a multiply-add counts as one FLOP.
@@ -159,7 +162,9 @@ def flops_to_string(flops, units='GFLOPs', precision=2):
         return str(flops) + ' FLOPs'


-def params_to_string(num_params, units=None, precision=2):
+def params_to_string(num_params: float,
+                     units: Optional[str] = None,
+                     precision: int = 2) -> str:
     """Convert parameter number into a string.

     Args:
@@ -196,13 +201,13 @@ def params_to_string(num_params, units=None, precision=2):
         return str(num_params)


-def print_model_with_flops(model,
-                           total_flops,
-                           total_params,
-                           units='GFLOPs',
-                           precision=3,
-                           ost=sys.stdout,
-                           flush=False):
+def print_model_with_flops(model: nn.Module,
+                           total_flops: float,
+                           total_params: float,
+                           units: Optional[str] = 'GFLOPs',
+                           precision: int = 3,
+                           ost: TextIO = sys.stdout,
+                           flush: bool = False) -> None:
     """Print a model with FLOPs for each layer.

     Args:
@@ -305,7 +310,7 @@ def print_model_with_flops(model,
     model.apply(del_extra_repr)


-def get_model_parameters_number(model):
+def get_model_parameters_number(model: nn.Module) -> float:
     """Calculate parameter number of a model.

     Args:
@@ -318,16 +323,16 @@ def get_model_parameters_number(model):
     return num_params


-def add_flops_counting_methods(net_main_module):
+def add_flops_counting_methods(net_main_module: nn.Module) -> nn.Module:
     # adding additional methods to the existing module object,
     # this is done this way so that each function has access to self object
-    net_main_module.start_flops_count = start_flops_count.__get__(
+    net_main_module.start_flops_count = start_flops_count.__get__(  # type: ignore # noqa E501
         net_main_module)
-    net_main_module.stop_flops_count = stop_flops_count.__get__(
+    net_main_module.stop_flops_count = stop_flops_count.__get__(  # type: ignore # noqa E501
         net_main_module)
-    net_main_module.reset_flops_count = reset_flops_count.__get__(
+    net_main_module.reset_flops_count = reset_flops_count.__get__(  # type: ignore # noqa E501
         net_main_module)
-    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(  # noqa: E501
+    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(  # type: ignore # noqa E501
         net_main_module)

     net_main_module.reset_flops_count()
@@ -335,7 +340,7 @@ def add_flops_counting_methods(net_main_module):
     return net_main_module


-def compute_average_flops_cost(self):
+def compute_average_flops_cost(self) -> Tuple[float, float]:
     """Compute average FLOPs cost.

     A method to compute average FLOPs cost, which will be available after
@@ -353,7 +358,7 @@ def compute_average_flops_cost(self):
     return flops_sum / batches_count, params_sum


-def start_flops_count(self):
+def start_flops_count(self) -> None:
     """Activate the computation of mean flops consumption per image.

     A method to activate the computation of mean flops consumption per image.
@@ -362,7 +367,7 @@ def start_flops_count(self):
     """
     add_batch_counter_hook_function(self)

-    def add_flops_counter_hook_function(module):
+    def add_flops_counter_hook_function(module: nn.Module) -> None:
         if is_supported_instance(module):
             if hasattr(module, '__flops_handle__'):
                 return
@@ -376,7 +381,7 @@ def start_flops_count(self):
     self.apply(partial(add_flops_counter_hook_function))


-def stop_flops_count(self):
+def stop_flops_count(self) -> None:
     """Stop computing the mean flops consumption per image.

     A method to stop computing the mean flops consumption per image, which will
@@ -387,7 +392,7 @@ def stop_flops_count(self):
     self.apply(remove_flops_counter_hook_function)


-def reset_flops_count(self):
+def reset_flops_count(self) -> None:
     """Reset statistics computed so far.

     A method to Reset computed statistics, which will be available after
@@ -398,11 +403,13 @@ def reset_flops_count(self):
 # ---- Internal functions
-def empty_flops_counter_hook(module, input, output):
+def empty_flops_counter_hook(module: nn.Module, input: tuple,
+                             output: Any) -> None:
     module.__flops__ += 0


-def upsample_flops_counter_hook(module, input, output):
+def upsample_flops_counter_hook(module: nn.Module, input: tuple,
+                                output: torch.Tensor) -> None:
     output_size = output[0]
     batch_size = output_size.shape[0]
     output_elements_count = batch_size
@@ -411,39 +418,38 @@ def upsample_flops_counter_hook(module, input, output):
     module.__flops__ += int(output_elements_count)


-def relu_flops_counter_hook(module, input, output):
+def relu_flops_counter_hook(module: nn.Module, input: tuple,
+                            output: torch.Tensor) -> None:
     active_elements_count = output.numel()
     module.__flops__ += int(active_elements_count)


-def linear_flops_counter_hook(module, input, output):
-    input = input[0]
+def linear_flops_counter_hook(module: nn.Module, input: tuple,
+                              output: torch.Tensor) -> None:
     output_last_dim = output.shape[
         -1]  # pytorch checks dimensions, so here we don't care much
-    module.__flops__ += int(np.prod(input.shape) * output_last_dim)
+    module.__flops__ += int(np.prod(input[0].shape) * output_last_dim)


-def pool_flops_counter_hook(module, input, output):
-    input = input[0]
-    module.__flops__ += int(np.prod(input.shape))
+def pool_flops_counter_hook(module: nn.Module, input: tuple,
+                            output: torch.Tensor) -> None:
+    module.__flops__ += int(np.prod(input[0].shape))


-def norm_flops_counter_hook(module, input, output):
-    input = input[0]
-
-    batch_flops = np.prod(input.shape)
+def norm_flops_counter_hook(module: nn.Module, input: tuple,
+                            output: torch.Tensor) -> None:
+    batch_flops = np.prod(input[0].shape)
     if (getattr(module, 'affine', False)
             or getattr(module, 'elementwise_affine', False)):
         batch_flops *= 2
     module.__flops__ += int(batch_flops)


-def deconv_flops_counter_hook(conv_module, input, output):
+def deconv_flops_counter_hook(conv_module: nn.Module, input: tuple,
+                              output: torch.Tensor) -> None:
     # Can have multiple inputs, getting the first one
-    input = input[0]
-
-    batch_size = input.shape[0]
-    input_height, input_width = input.shape[2:]
+    batch_size = input[0].shape[0]
+    input_height, input_width = input[0].shape[2:]

     kernel_height, kernel_width = conv_module.kernel_size
     in_channels = conv_module.in_channels
@@ -465,11 +471,10 @@ def deconv_flops_counter_hook(conv_module, input, output):
     conv_module.__flops__ += int(overall_flops)


-def conv_flops_counter_hook(conv_module, input, output):
+def conv_flops_counter_hook(conv_module: nn.Module, input: tuple,
+                            output: torch.Tensor) -> None:
     # Can have multiple inputs, getting the first one
-    input = input[0]
-
-    batch_size = input.shape[0]
+    batch_size = input[0].shape[0]
     output_dims = list(output.shape[2:])

     kernel_dims = list(conv_module.kernel_size)
@@ -496,24 +501,23 @@ def conv_flops_counter_hook(conv_module, input, output):
     conv_module.__flops__ += int(overall_flops)
-def batch_counter_hook(module, input, output):
+def batch_counter_hook(module: nn.Module, input: tuple, output: Any) -> None:
     batch_size = 1
     if len(input) > 0:
         # Can have multiple inputs, getting the first one
-        input = input[0]
-        batch_size = len(input)
+        batch_size = len(input[0])
     else:
         warnings.warn('No positional inputs found for a module, '
                       'assuming batch size is 1.')
     module.__batch_counter__ += batch_size


-def add_batch_counter_variables_or_reset(module):
+def add_batch_counter_variables_or_reset(module: nn.Module) -> None:
     module.__batch_counter__ = 0


-def add_batch_counter_hook_function(module):
+def add_batch_counter_hook_function(module: nn.Module) -> None:
     if hasattr(module, '__batch_counter_handle__'):
         return
@@ -521,13 +525,13 @@ def add_batch_counter_hook_function(module):
     module.__batch_counter_handle__ = handle


-def remove_batch_counter_hook_function(module):
+def remove_batch_counter_hook_function(module: nn.Module) -> None:
     if hasattr(module, '__batch_counter_handle__'):
         module.__batch_counter_handle__.remove()
         del module.__batch_counter_handle__


-def add_flops_counter_variable_or_reset(module):
+def add_flops_counter_variable_or_reset(module: nn.Module) -> None:
     if is_supported_instance(module):
         if hasattr(module, '__flops__') or hasattr(module, '__params__'):
             warnings.warn('variables __flops__ or __params__ are already '
@@ -537,20 +541,20 @@ def add_flops_counter_variable_or_reset(module):
         module.__params__ = get_model_parameters_number(module)


-def is_supported_instance(module):
+def is_supported_instance(module: nn.Module) -> bool:
     if type(module) in get_modules_mapping():
         return True
     return False


-def remove_flops_counter_hook_function(module):
+def remove_flops_counter_hook_function(module: nn.Module) -> None:
     if is_supported_instance(module):
         if hasattr(module, '__flops_handle__'):
             module.__flops_handle__.remove()
             del module.__flops_handle__


-def get_modules_mapping():
+def get_modules_mapping() -> Dict:
     return {
         # convolutions
         nn.Conv1d: conv_flops_counter_hook,
......
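For reference, a minimal usage sketch of the annotated flops-counter API; the TinyNet module below is illustrative only and not part of this diff.

import torch.nn as nn

from mmcv.cnn import get_model_complexity_info


class TinyNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.bn(self.conv(x)).mean(dim=(2, 3))
        return self.fc(x)


# input_shape is annotated as a tuple (C, H, W); with as_strings=True the
# results go through flops_to_string/params_to_string and come back as
# human-readable strings, otherwise as raw numbers.
flops, params = get_model_complexity_info(
    TinyNet(), (3, 32, 32), print_per_layer_stat=False, as_strings=True)
print(flops, params)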
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn


-def _fuse_conv_bn(conv, bn):
+def _fuse_conv_bn(conv: nn.Module, bn: nn.Module) -> nn.Module:
     """Fuse conv and bn into one module.

     Args:
@@ -24,7 +24,7 @@ def _fuse_conv_bn(conv, bn):
     return conv


-def fuse_conv_bn(module):
+def fuse_conv_bn(module: nn.Module) -> nn.Module:
     """Recursively fuse conv and bn in a module.

     During inference, the functionary of batch norm layers is turned off
......
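A small sketch of the now-annotated fuse_conv_bn on an eval-mode model; the Sequential below is a stand-in, not code from this commit.

import torch
import torch.nn as nn

from mmcv.cnn import fuse_conv_bn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU())
model.eval()  # fusing conv and bn is an inference-time optimization

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    y_ref = model(x)
    fused = fuse_conv_bn(model)  # returns an nn.Module with BN folded away
    y_fused = fused(x)

# the fused model should reproduce the original outputs up to float precision
assert torch.allclose(y_ref, y_fused, atol=1e-5)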
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+import torch.nn as nn

 import mmcv


-class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
+class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
     """A general BatchNorm layer without input dimension check.

     Reproduced from @kapily's work:
@@ -15,11 +16,11 @@ class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
     SyncBatchNorm.
     """

-    def _check_input_dim(self, input):
+    def _check_input_dim(self, input: torch.Tensor):
         return


-def revert_sync_batchnorm(module):
+def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
     """Helper function to convert all `SyncBatchNorm` (SyncBN) and
     `mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to
     `BatchNormXd` layers.
......
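A hedged sketch of revert_sync_batchnorm, which the annotations above now declare as nn.Module -> nn.Module; the toy Sequential is illustrative only, and the import path assumes the usual mmcv.cnn.utils export.

import torch.nn as nn

from mmcv.cnn.utils import revert_sync_batchnorm

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.SyncBatchNorm(16),
    nn.ReLU())

# SyncBatchNorm layers are swapped for _BatchNormXd, which skips the input
# dimension check, so the converted model can run outside distributed training.
model = revert_sync_batchnorm(model)
print(type(model[1]).__name__)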
@@ -2,6 +2,7 @@
 import copy
 import math
 import warnings
+from typing import Dict, List, Optional, Union

 import numpy as np
 import torch
@@ -13,7 +14,7 @@ from mmcv.utils import Registry, build_from_cfg, get_logger, print_log
 INITIALIZERS = Registry('initializer')


-def update_init_info(module, init_info):
+def update_init_info(module: nn.Module, init_info: str) -> None:
     """Update the `_params_init_info` in the module if the value of parameters
     are changed.

@@ -45,14 +46,17 @@ def update_init_info(module, init_info):
             module._params_init_info[param]['tmp_mean_value'] = mean_value


-def constant_init(module, val, bias=0):
+def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
     if hasattr(module, 'weight') and module.weight is not None:
         nn.init.constant_(module.weight, val)
     if hasattr(module, 'bias') and module.bias is not None:
         nn.init.constant_(module.bias, bias)
-def xavier_init(module, gain=1, bias=0, distribution='normal'):
+def xavier_init(module: nn.Module,
+                gain: float = 1,
+                bias: float = 0,
+                distribution: str = 'normal') -> None:
     assert distribution in ['uniform', 'normal']
     if hasattr(module, 'weight') and module.weight is not None:
         if distribution == 'uniform':
@@ -63,7 +67,10 @@ def xavier_init(module, gain=1, bias=0, distribution='normal'):
         nn.init.constant_(module.bias, bias)


-def normal_init(module, mean=0, std=1, bias=0):
+def normal_init(module: nn.Module,
+                mean: float = 0,
+                std: float = 1,
+                bias: float = 0) -> None:
     if hasattr(module, 'weight') and module.weight is not None:
         nn.init.normal_(module.weight, mean, std)
     if hasattr(module, 'bias') and module.bias is not None:
@@ -82,19 +89,22 @@ def trunc_normal_init(module: nn.Module,
         nn.init.constant_(module.bias, bias)  # type: ignore


-def uniform_init(module, a=0, b=1, bias=0):
+def uniform_init(module: nn.Module,
+                 a: float = 0,
+                 b: float = 1,
+                 bias: float = 0) -> None:
     if hasattr(module, 'weight') and module.weight is not None:
         nn.init.uniform_(module.weight, a, b)
     if hasattr(module, 'bias') and module.bias is not None:
         nn.init.constant_(module.bias, bias)


-def kaiming_init(module,
-                 a=0,
-                 mode='fan_out',
-                 nonlinearity='relu',
-                 bias=0,
-                 distribution='normal'):
+def kaiming_init(module: nn.Module,
+                 a: float = 0,
+                 mode: str = 'fan_out',
+                 nonlinearity: str = 'relu',
+                 bias: float = 0,
+                 distribution: str = 'normal') -> None:
     assert distribution in ['uniform', 'normal']
     if hasattr(module, 'weight') and module.weight is not None:
         if distribution == 'uniform':
@@ -107,7 +117,7 @@ def kaiming_init(module,
         nn.init.constant_(module.bias, bias)


-def caffe2_xavier_init(module, bias=0):
+def caffe2_xavier_init(module: nn.Module, bias: float = 0) -> None:
     # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
     # Acknowledgment to FAIR's internal code
     kaiming_init(
@@ -119,19 +129,23 @@ def caffe2_xavier_init(module, bias=0):
         distribution='uniform')


-def bias_init_with_prob(prior_prob):
+def bias_init_with_prob(prior_prob: float) -> float:
     """initialize conv/fc bias value according to a given probability value."""
     bias_init = float(-np.log((1 - prior_prob) / prior_prob))
     return bias_init


-def _get_bases_name(m):
+def _get_bases_name(m: nn.Module) -> List[str]:
     return [b.__name__ for b in m.__class__.__bases__]
 class BaseInit:

-    def __init__(self, *, bias=0, bias_prob=None, layer=None):
+    def __init__(self,
+                 *,
+                 bias: float = 0,
+                 bias_prob: Optional[float] = None,
+                 layer: Union[str, List, None] = None):
         self.wholemodule = False
         if not isinstance(bias, (int, float)):
             raise TypeError(f'bias must be a number, but got a {type(bias)}')
@@ -154,7 +168,7 @@ class BaseInit:
         self.bias = bias
         self.layer = [layer] if isinstance(layer, str) else layer

-    def _get_init_info(self):
+    def _get_init_info(self) -> str:
         info = f'{self.__class__.__name__}, bias={self.bias}'
         return info

@@ -172,11 +186,11 @@ class ConstantInit(BaseInit):
             Defaults to None.
     """

-    def __init__(self, val, **kwargs):
+    def __init__(self, val: Union[int, float], **kwargs):
         super().__init__(**kwargs)
         self.val = val

-    def __call__(self, module):
+    def __call__(self, module: nn.Module) -> None:

         def init(m):
             if self.wholemodule:
@@ -191,7 +205,7 @@ class ConstantInit(BaseInit):
         if hasattr(module, '_params_init_info'):
             update_init_info(module, init_info=self._get_init_info())

-    def _get_init_info(self):
+    def _get_init_info(self) -> str:
         info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
         return info
@@ -214,12 +228,15 @@ class XavierInit(BaseInit):
             Defaults to None.
     """

-    def __init__(self, gain=1, distribution='normal', **kwargs):
+    def __init__(self,
+                 gain: float = 1,
+                 distribution: str = 'normal',
+                 **kwargs):
         super().__init__(**kwargs)
         self.gain = gain
         self.distribution = distribution

-    def __call__(self, module):
+    def __call__(self, module: nn.Module) -> None:

         def init(m):
             if self.wholemodule:
@@ -234,7 +251,7 @@ class XavierInit(BaseInit):
         if hasattr(module, '_params_init_info'):
             update_init_info(module, init_info=self._get_init_info())

-    def _get_init_info(self):
+    def _get_init_info(self) -> str:
         info = f'{self.__class__.__name__}: gain={self.gain}, ' \
                f'distribution={self.distribution}, bias={self.bias}'
         return info

@@ -257,12 +274,12 @@ class NormalInit(BaseInit):
     """

-    def __init__(self, mean=0, std=1, **kwargs):
+    def __init__(self, mean: float = 0, std: float = 1, **kwargs):
         super().__init__(**kwargs)
         self.mean = mean
         self.std = std

-    def __call__(self, module):
+    def __call__(self, module: nn.Module) -> None:

         def init(m):
             if self.wholemodule:
@@ -277,7 +294,7 @@ class NormalInit(BaseInit):
         if hasattr(module, '_params_init_info'):
             update_init_info(module, init_info=self._get_init_info())

-    def _get_init_info(self):
+    def _get_init_info(self) -> str:
         info = f'{self.__class__.__name__}: mean={self.mean},' \
                f' std={self.std}, bias={self.bias}'
         return info

@@ -355,12 +372,12 @@ class UniformInit(BaseInit):
             Defaults to None.
     """

-    def __init__(self, a=0, b=1, **kwargs):
+    def __init__(self, a: float = 0., b: float = 1., **kwargs):
         super().__init__(**kwargs)
         self.a = a
         self.b = b

-    def __call__(self, module):
+    def __call__(self, module: nn.Module) -> None:

         def init(m):
             if self.wholemodule:
@@ -375,7 +392,7 @@ class UniformInit(BaseInit):
         if hasattr(module, '_params_init_info'):
             update_init_info(module, init_info=self._get_init_info())

-    def _get_init_info(self):
+    def _get_init_info(self) -> str:
         info = f'{self.__class__.__name__}: a={self.a},' \
                f' b={self.b}, bias={self.bias}'
         return info
@@ -409,10 +426,10 @@ class KaimingInit(BaseInit):
     """

     def __init__(self,
-                 a=0,
-                 mode='fan_out',
-                 nonlinearity='relu',
-                 distribution='normal',
+                 a: float = 0,
+                 mode: str = 'fan_out',
+                 nonlinearity: str = 'relu',
+                 distribution: str = 'normal',
                  **kwargs):
         super().__init__(**kwargs)
         self.a = a
@@ -420,7 +437,7 @@ class KaimingInit(BaseInit):
         self.nonlinearity = nonlinearity
         self.distribution = distribution

-    def __call__(self, module):
+    def __call__(self, module: nn.Module) -> None:

         def init(m):
             if self.wholemodule:
@@ -437,7 +454,7 @@ class KaimingInit(BaseInit):
         if hasattr(module, '_params_init_info'):
             update_init_info(module, init_info=self._get_init_info())

-    def _get_init_info(self):
+    def _get_init_info(self) -> str:
         info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
                f'nonlinearity={self.nonlinearity}, ' \
                f'distribution ={self.distribution}, bias={self.bias}'
@@ -456,7 +473,7 @@ class Caffe2XavierInit(KaimingInit):
             distribution='uniform',
             **kwargs)

-    def __call__(self, module):
+    def __call__(self, module: nn.Module) -> None:
         super().__call__(module)
@@ -475,12 +492,15 @@ class PretrainedInit:
         map_location (str): map tensors into proper locations.
     """

-    def __init__(self, checkpoint, prefix=None, map_location=None):
+    def __init__(self,
+                 checkpoint: str,
+                 prefix: Optional[str] = None,
+                 map_location: Optional[str] = None):
         self.checkpoint = checkpoint
         self.prefix = prefix
         self.map_location = map_location

-    def __call__(self, module):
+    def __call__(self, module: nn.Module) -> None:
         from mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
                                  load_state_dict)
         logger = get_logger('mmcv')
@@ -503,12 +523,14 @@ class PretrainedInit:
         if hasattr(module, '_params_init_info'):
             update_init_info(module, init_info=self._get_init_info())

-    def _get_init_info(self):
+    def _get_init_info(self) -> str:
         info = f'{self.__class__.__name__}: load from {self.checkpoint}'
         return info


-def _initialize(module, cfg, wholemodule=False):
+def _initialize(module: nn.Module,
+                cfg: Dict,
+                wholemodule: bool = False) -> None:
     func = build_from_cfg(cfg, INITIALIZERS)
     # wholemodule flag is for override mode, there is no layer key in override
     # and initializer will give init values for the whole module with the name
@@ -517,7 +539,8 @@ def _initialize(module, cfg, wholemodule=False):
     func(module)


-def _initialize_override(module, override, cfg):
+def _initialize_override(module: nn.Module, override: Union[Dict, List],
+                         cfg: Dict) -> None:
     if not isinstance(override, (dict, list)):
         raise TypeError(f'override must be a dict or a list of dict, \
                 but got {type(override)}')
@@ -547,7 +570,7 @@ def _initialize_override(module, override, cfg):
                         f'but init_cfg is {cp_override}.')


-def initialize(module, init_cfg):
+def initialize(module: nn.Module, init_cfg: Union[Dict, List[dict]]) -> None:
     r"""Initialize a module.

     Args:
......
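To round off, a brief sketch exercising a few of the annotated initializers and the config-driven initialize() entry point; the layers and config values here are arbitrary examples, not taken from this commit.

import torch.nn as nn

from mmcv.cnn import (bias_init_with_prob, constant_init, initialize,
                      kaiming_init, normal_init)

conv = nn.Conv2d(3, 16, 3)
bn = nn.BatchNorm2d(16)
fc = nn.Linear(16, 10)

# the float-typed defaults still accept plain ints such as bias=0
kaiming_init(conv, a=0, mode='fan_out', nonlinearity='relu')
constant_init(bn, val=1, bias=0)
normal_init(fc, mean=0, std=0.01, bias=bias_init_with_prob(0.01))

# initialize() accepts a single dict or a list of dicts as init_cfg
model = nn.Sequential(conv, bn, nn.ReLU(), fc)
initialize(model, init_cfg=[
    dict(type='Kaiming', layer='Conv2d'),
    dict(type='Constant', layer='BatchNorm2d', val=1, bias=0),
])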
@@ -40,7 +40,7 @@ def _get_mmcv_home():


 def load_state_dict(module: torch.nn.Module,
-                    state_dict: OrderedDict,
+                    state_dict: Union[dict, OrderedDict],
                     strict: bool = False,
                     logger: Optional[logging.Logger] = None) -> None:
     """Load state_dict to a module.
......
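Finally, a minimal sketch of the relaxed load_state_dict signature, which now accepts a plain dict as well as an OrderedDict; the two-layer model is illustrative only, and the top-level import path is the usual mmcv.runner export.

import torch.nn as nn

from mmcv.runner import load_state_dict

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
# a plain dict keyed by parameter names, e.g. as produced by some converters
weights = {k: v.clone() for k, v in model.state_dict().items()}

load_state_dict(model, weights, strict=True)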