"megatron/vscode:/vscode.git/clone" did not exist on "a6e00d97f140f04cff60a4eac46ac07a5b4cb9d9"
Unverified commit 19a02415, authored by Mashiro, committed by GitHub

[Refactor] Use MODELS registry in mmengine and delete basemodule (#2172)

* change MODELS to mmengine, delete basemodule

* fix unit test

* remove build from cfg

* add comment and rename TARGET_MODELS to registry

* refine cnn docs

* remove unnecessary check

* refine as comment

* refine build_xxx_conv error message

* fix lint

* fix import registry from mmcv

* remove unused file
parent f6fd6c21
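The diff below applies one mechanical change throughout mmcv: the per-type registries (CONV_LAYERS, NORM_LAYERS, ATTENTION, ...) are dropped in favour of mmengine's shared MODELS registry, and BaseModule and its containers now come from mmengine.model. A minimal sketch of the new registration pattern, assuming mmengine is installed; MyConv is a hypothetical layer used only for illustration:
import torch.nn as nn
from mmengine.registry import MODELS
# Register a custom layer into mmengine's shared MODELS registry. Before this
# refactor the same layer would have been registered in a dedicated mmcv
# registry such as CONV_LAYERS.
@MODELS.register_module()
class MyConv(nn.Conv2d):
    pass
# Build it from a config dict, which is what build_conv_layer() and the other
# builders now do internally via MODELS.build().
layer = MODELS.build(dict(type='MyConv', in_channels=3, out_channels=8, kernel_size=3))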
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry
CONV_LAYERS = Registry('conv layer')
NORM_LAYERS = Registry('norm layer')
ACTIVATION_LAYERS = Registry('activation layer')
PADDING_LAYERS = Registry('padding layer')
UPSAMPLE_LAYERS = Registry('upsample layer')
PLUGIN_LAYERS = Registry('plugin layer')
DROPOUT_LAYERS = Registry('drop out layers')
POSITIONAL_ENCODING = Registry('position encoding')
ATTENTION = Registry('attention')
FEEDFORWARD_NETWORK = Registry('feed-forward Network')
TRANSFORMER_LAYER = Registry('transformerLayer')
TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.registry import MODELS
from .registry import ACTIVATION_LAYERS
@ACTIVATION_LAYERS.register_module()
@MODELS.register_module()
class Swish(nn.Module):
"""Swish Module.
......
......@@ -7,15 +7,14 @@ from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine import ConfigDict
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.registry import MODELS
from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
build_norm_layer)
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
to_2tuple)
from mmcv.utils import deprecated_api_warning, to_2tuple
from .drop import build_dropout
from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
......@@ -37,27 +36,27 @@ except ImportError:
def build_positional_encoding(cfg, default_args=None):
"""Builder for Position Encoding."""
return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
return MODELS.build(cfg, default_args=default_args)
def build_attention(cfg, default_args=None):
"""Builder for attention."""
return build_from_cfg(cfg, ATTENTION, default_args)
return MODELS.build(cfg, default_args=default_args)
def build_feedforward_network(cfg, default_args=None):
"""Builder for feed-forward network (FFN)."""
return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
return MODELS.build(cfg, default_args=default_args)
def build_transformer_layer(cfg, default_args=None):
"""Builder for transformer layer."""
return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
return MODELS.build(cfg, default_args=default_args)
def build_transformer_layer_sequence(cfg, default_args=None):
"""Builder for transformer encoder and transformer decoder."""
return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
return MODELS.build(cfg, default_args=default_args)
class AdaptivePadding(nn.Module):
......@@ -403,7 +402,7 @@ class PatchMerging(BaseModule):
return x, output_size
@ATTENTION.register_module()
@MODELS.register_module()
class MultiheadAttention(BaseModule):
"""A wrapper for ``torch.nn.MultiheadAttention``.
......@@ -551,7 +550,7 @@ class MultiheadAttention(BaseModule):
return identity + self.dropout_layer(self.proj_drop(out))
@FEEDFORWARD_NETWORK.register_module()
@MODELS.register_module()
class FFN(BaseModule):
"""Implements feed-forward networks (FFNs) with identity connection.
......@@ -628,7 +627,7 @@ class FFN(BaseModule):
return identity + self.dropout_layer(out)
@TRANSFORMER_LAYER.register_module()
@MODELS.register_module()
class BaseTransformerLayer(BaseModule):
"""Base `TransformerLayer` for vision transformer.
......@@ -859,7 +858,7 @@ class BaseTransformerLayer(BaseModule):
return query
@TRANSFORMER_LAYER_SEQUENCE.register_module()
@MODELS.register_module()
class TransformerLayerSequence(BaseModule):
"""Base class for TransformerEncoder and TransformerDecoder in vision
transformer.
......
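Because the builders above now just forward to MODELS.build, their public API is unchanged. A short usage sketch (MultiheadAttention ships with mmcv; the argument values are illustrative):
from mmcv.cnn.bricks.transformer import build_attention
# The config is passed straight to mmengine's MODELS.build() instead of
# build_from_cfg(cfg, ATTENTION).
attn = build_attention(dict(type='MultiheadAttention', embed_dims=256, num_heads=8))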
......@@ -4,15 +4,14 @@ from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model.utils import xavier_init
from mmengine.registry import MODELS
from ..utils import xavier_init
from .registry import UPSAMPLE_LAYERS
MODELS.register_module('nearest', module=nn.Upsample)
MODELS.register_module('bilinear', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
@MODELS.register_module(name='pixel_shuffle')
class PixelShufflePack(nn.Module):
"""Pixel Shuffle upsample layer.
......@@ -76,11 +75,15 @@ def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
cfg_ = cfg.copy()
layer_type = cfg_.pop('type')
if layer_type not in UPSAMPLE_LAYERS:
raise KeyError(f'Unrecognized upsample type {layer_type}')
else:
upsample = UPSAMPLE_LAYERS.get(layer_type)
# Switch registry to the target scope. If `upsample` cannot be found
# in the registry, fallback to search `upsample` in the
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
upsample = registry.get(layer_type)
if upsample is None:
raise KeyError(f'Cannot find {layer_type} in registry under scope '
f'name {registry.scope}')
if upsample is nn.Upsample:
cfg_['mode'] = layer_type
layer = upsample(*args, **kwargs, **cfg_)
......
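build_upsample_layer keeps its behaviour after the switch: the type string is looked up in MODELS (falling back across scopes), and the nn.Upsample aliases re-inject the popped type as the mode argument. A brief usage sketch with illustrative values:
from mmcv.cnn import build_upsample_layer
# 'nearest' and 'bilinear' are registered aliases of nn.Upsample, so this
# builds nn.Upsample(scale_factor=2, mode='nearest').
upsample = build_upsample_layer(dict(type='nearest', scale_factor=2))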
......@@ -9,10 +9,9 @@ import math
import torch
import torch.nn as nn
from mmengine.registry import MODELS
from torch.nn.modules.utils import _pair, _triple
from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
if torch.__version__ == 'parrots':
TORCH_VERSION = torch.__version__
else:
......@@ -38,7 +37,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
return NewEmptyTensorOp.apply(grad, shape), None
@CONV_LAYERS.register_module('Conv', force=True)
@MODELS.register_module('Conv', force=True)
class Conv2d(nn.Conv2d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -59,7 +58,7 @@ class Conv2d(nn.Conv2d):
return super().forward(x)
@CONV_LAYERS.register_module('Conv3d', force=True)
@MODELS.register_module('Conv3d', force=True)
class Conv3d(nn.Conv3d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -80,9 +79,8 @@ class Conv3d(nn.Conv3d):
return super().forward(x)
@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv')
@UPSAMPLE_LAYERS.register_module('deconv', force=True)
@MODELS.register_module()
@MODELS.register_module('deconv')
class ConvTranspose2d(nn.ConvTranspose2d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -103,9 +101,8 @@ class ConvTranspose2d(nn.ConvTranspose2d):
return super().forward(x)
@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv3d')
@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
@MODELS.register_module()
@MODELS.register_module('deconv3d')
class ConvTranspose3d(nn.ConvTranspose3d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
......
# Copyright (c) OpenMMLab. All rights reserved.
from ..runner import Sequential
from ..utils import Registry, build_from_cfg
def build_model_from_cfg(cfg, registry, default_args=None):
"""Build a PyTorch model from config dict(s). Different from
``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
Args:
cfg (dict, list[dict]): The config of modules; it is either a config
dict or a list of config dicts. If cfg is a list, the built
modules will be wrapped with ``nn.Sequential``.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.
Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
MODELS = Registry('model', build_func=build_model_from_cfg)
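The removed mmcv builder's list handling is not lost: mmengine's MODELS registry is created with an equivalent build_model_from_cfg, so a list of configs still yields a Sequential. A hedged sketch with a throwaway registry (TOY and the layer choices are illustrative, and this assumes build_model_from_cfg is exported from mmengine.registry, as in current mmengine):
import torch.nn as nn
from mmengine.registry import Registry, build_model_from_cfg
# Hypothetical registry using the same build_func as mmengine's MODELS.
TOY = Registry('toy', build_func=build_model_from_cfg)
TOY.register_module(module=nn.Flatten)
TOY.register_module(module=nn.ReLU)
single = TOY.build(dict(type='ReLU'))                           # one module
stacked = TOY.build([dict(type='Flatten'), dict(type='ReLU')])  # wrapped in Sequential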
......@@ -4,10 +4,9 @@ from typing import Optional, Sequence, Tuple, Union
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmengine.model.utils import constant_init, kaiming_init
from torch import Tensor
from .utils import constant_init, kaiming_init
def conv3x3(in_planes: int,
out_planes: int,
......
......@@ -2,18 +2,7 @@
from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn
from .sync_bn import revert_sync_batchnorm
from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
KaimingInit, NormalInit, PretrainedInit,
TruncNormalInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init,
constant_init, initialize, kaiming_init, normal_init,
trunc_normal_init, uniform_init, xavier_init)
__all__ = [
'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'Caffe2XavierInit', 'revert_sync_batchnorm'
'get_model_complexity_info', 'fuse_conv_bn', 'revert_sync_batchnorm'
]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import warnings
from typing import Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from mmcv.utils import Registry, build_from_cfg, get_logger, print_log
INITIALIZERS = Registry('initializer')
def update_init_info(module: nn.Module, init_info: str) -> None:
"""Update the `_params_init_info` in the module if the value of parameters
are changed.
Args:
module (obj:`nn.Module`): The module of PyTorch with a user-defined
attribute `_params_init_info` which records the initialization
information.
init_info (str): The string that describes the initialization.
"""
assert hasattr(
module,
'_params_init_info'), f'Can not find `_params_init_info` in {module}'
for name, param in module.named_parameters():
assert param in module._params_init_info, (
f'Find a new :obj:`Parameter` '
f'named `{name}` during executing the '
f'`init_weights` of '
f'`{module.__class__.__name__}`. '
f'Please do not add or '
f'replace parameters during executing '
f'the `init_weights`. ')
# The parameter has been changed during executing the
# `init_weights` of module
mean_value = param.data.mean()
if module._params_init_info[param]['tmp_mean_value'] != mean_value:
module._params_init_info[param]['init_info'] = init_info
module._params_init_info[param]['tmp_mean_value'] = mean_value
def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
if hasattr(module, 'weight') and module.weight is not None:
nn.init.constant_(module.weight, val)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def xavier_init(module: nn.Module,
gain: float = 1,
bias: float = 0,
distribution: str = 'normal') -> None:
assert distribution in ['uniform', 'normal']
if hasattr(module, 'weight') and module.weight is not None:
if distribution == 'uniform':
nn.init.xavier_uniform_(module.weight, gain=gain)
else:
nn.init.xavier_normal_(module.weight, gain=gain)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def normal_init(module: nn.Module,
mean: float = 0,
std: float = 1,
bias: float = 0) -> None:
if hasattr(module, 'weight') and module.weight is not None:
nn.init.normal_(module.weight, mean, std)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def trunc_normal_init(module: nn.Module,
mean: float = 0,
std: float = 1,
a: float = -2,
b: float = 2,
bias: float = 0) -> None:
if hasattr(module, 'weight') and module.weight is not None:
trunc_normal_(module.weight, mean, std, a, b) # type: ignore
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias) # type: ignore
def uniform_init(module: nn.Module,
a: float = 0,
b: float = 1,
bias: float = 0) -> None:
if hasattr(module, 'weight') and module.weight is not None:
nn.init.uniform_(module.weight, a, b)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def kaiming_init(module: nn.Module,
a: float = 0,
mode: str = 'fan_out',
nonlinearity: str = 'relu',
bias: float = 0,
distribution: str = 'normal') -> None:
assert distribution in ['uniform', 'normal']
if hasattr(module, 'weight') and module.weight is not None:
if distribution == 'uniform':
nn.init.kaiming_uniform_(
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
else:
nn.init.kaiming_normal_(
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def caffe2_xavier_init(module: nn.Module, bias: float = 0) -> None:
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
kaiming_init(
module,
a=1,
mode='fan_in',
nonlinearity='leaky_relu',
bias=bias,
distribution='uniform')
def bias_init_with_prob(prior_prob: float) -> float:
"""initialize conv/fc bias value according to a given probability value."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init
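# Worked example for the helper above: the bias is chosen so that
# sigmoid(bias) equals the given prior probability, i.e.
#   bias = -log((1 - p) / p).
# For p = 0.01 (a common choice for focal-loss classification heads):
#   bias = -log(99) ≈ -4.595
# A quick sanity check (illustrative, not part of this commit):
#   import math
#   assert abs(bias_init_with_prob(0.01) + math.log(99)) < 1e-6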
def _get_bases_name(m: nn.Module) -> List[str]:
return [b.__name__ for b in m.__class__.__bases__]
class BaseInit:
def __init__(self,
*,
bias: float = 0,
bias_prob: Optional[float] = None,
layer: Union[str, List, None] = None):
self.wholemodule = False
if not isinstance(bias, (int, float)):
raise TypeError(f'bias must be a number, but got a {type(bias)}')
if bias_prob is not None:
if not isinstance(bias_prob, float):
raise TypeError(f'bias_prob type must be float, \
but got {type(bias_prob)}')
if layer is not None:
if not isinstance(layer, (str, list)):
raise TypeError(f'layer must be a str or a list of str, \
but got a {type(layer)}')
else:
layer = []
if bias_prob is not None:
self.bias = bias_init_with_prob(bias_prob)
else:
self.bias = bias
self.layer = [layer] if isinstance(layer, str) else layer
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Constant')
class ConstantInit(BaseInit):
"""Initialize module parameters with constant values.
Args:
val (int | float): the value to fill the weights in the module with
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer(s) to be initialized.
Defaults to None.
"""
def __init__(self, val: Union[int, float], **kwargs):
super().__init__(**kwargs)
self.val = val
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
constant_init(m, self.val, self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
constant_init(m, self.val, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Xavier')
class XavierInit(BaseInit):
r"""Initialize module parameters with values according to the method
described in `Understanding the difficulty of training deep feedforward
neural networks - Glorot, X. & Bengio, Y. (2010).
<http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
Args:
gain (int | float): an optional scaling factor. Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
distribution (str): distribution either be ``'normal'``
or ``'uniform'``. Defaults to ``'normal'``.
layer (str | list[str], optional): the layer(s) to be initialized.
Defaults to None.
"""
def __init__(self,
gain: float = 1,
distribution: str = 'normal',
**kwargs):
super().__init__(**kwargs)
self.gain = gain
self.distribution = distribution
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
xavier_init(m, self.gain, self.bias, self.distribution)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
xavier_init(m, self.gain, self.bias, self.distribution)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}: gain={self.gain}, ' \
f'distribution={self.distribution}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Normal')
class NormalInit(BaseInit):
r"""Initialize module parameters with the values drawn from the normal
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
Args:
mean (int | float): the mean of the normal distribution. Defaults to 0.
std (int | float): the standard deviation of the normal distribution.
Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer(s) to be initialized.
Defaults to None.
"""
def __init__(self, mean: float = 0, std: float = 1, **kwargs):
super().__init__(**kwargs)
self.mean = mean
self.std = std
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
normal_init(m, self.mean, self.std, self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
normal_init(m, self.mean, self.std, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}: mean={self.mean},' \
f' std={self.std}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='TruncNormal')
class TruncNormalInit(BaseInit):
r"""Initialize module parameters with the values drawn from the normal
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
outside :math:`[a, b]` redrawn until they are within the bounds.
Args:
mean (float): the mean of the normal distribution. Defaults to 0.
std (float): the standard deviation of the normal distribution.
Defaults to 1.
a (float): The minimum cutoff value.
b (float): The maximum cutoff value.
bias (float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer(s) to be initialized.
Defaults to None.
"""
def __init__(self,
mean: float = 0,
std: float = 1,
a: float = -2,
b: float = 2,
**kwargs) -> None:
super().__init__(**kwargs)
self.mean = mean
self.std = std
self.a = a
self.b = b
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
trunc_normal_init(m, self.mean, self.std, self.a, self.b,
self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
trunc_normal_init(m, self.mean, self.std, self.a, self.b,
self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
f' mean={self.mean}, std={self.std}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Uniform')
class UniformInit(BaseInit):
r"""Initialize module parameters with values drawn from the uniform
distribution :math:`\mathcal{U}(a, b)`.
Args:
a (int | float): the lower bound of the uniform distribution.
Defaults to 0.
b (int | float): the upper bound of the uniform distribution.
Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer(s) to be initialized.
Defaults to None.
"""
def __init__(self, a: float = 0., b: float = 1., **kwargs):
super().__init__(**kwargs)
self.a = a
self.b = b
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
uniform_init(m, self.a, self.b, self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
uniform_init(m, self.a, self.b, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}: a={self.a},' \
f' b={self.b}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Kaiming')
class KaimingInit(BaseInit):
r"""Initialize module parameters with the values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification - He, K. et al. (2015).
<https://www.cv-foundation.org/openaccess/content_iccv_2015/
papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_
Args:
a (int | float): the negative slope of the rectifier used after this
layer (only used with ``'leaky_relu'``). Defaults to 0.
mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
``'fan_in'`` preserves the magnitude of the variance of the weights
in the forward pass. Choosing ``'fan_out'`` preserves the
magnitudes in the backwards pass. Defaults to ``'fan_out'``.
nonlinearity (str): the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
Defaults to 'relu'.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
distribution (str): distribution either be ``'normal'`` or
``'uniform'``. Defaults to ``'normal'``.
layer (str | list[str], optional): the layer(s) to be initialized.
Defaults to None.
"""
def __init__(self,
a: float = 0,
mode: str = 'fan_out',
nonlinearity: str = 'relu',
distribution: str = 'normal',
**kwargs):
super().__init__(**kwargs)
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
self.distribution = distribution
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
f'nonlinearity={self.nonlinearity}, ' \
f'distribution={self.distribution}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Caffe2Xavier')
class Caffe2XavierInit(KaimingInit):
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
def __init__(self, **kwargs):
super().__init__(
a=1,
mode='fan_in',
nonlinearity='leaky_relu',
distribution='uniform',
**kwargs)
def __call__(self, module: nn.Module) -> None:
super().__call__(module)
@INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit:
"""Initialize module by loading a pretrained model.
Args:
checkpoint (str): the checkpoint file of the pretrained model to be
loaded.
prefix (str, optional): the prefix of a sub-module in the pretrained
model. It is used to load part of the pretrained model to
initialize a sub-module. For example, if we would like to load only
the backbone of a detector model, we can set ``prefix='backbone.'``.
Defaults to None.
map_location (str): map tensors into proper locations.
"""
def __init__(self,
checkpoint: str,
prefix: Optional[str] = None,
map_location: Optional[str] = None):
self.checkpoint = checkpoint
self.prefix = prefix
self.map_location = map_location
def __call__(self, module: nn.Module) -> None:
from mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
load_state_dict)
logger = get_logger('mmcv')
if self.prefix is None:
print_log(f'load model from: {self.checkpoint}', logger=logger)
load_checkpoint(
module,
self.checkpoint,
map_location=self.map_location,
strict=False,
logger=logger)
else:
print_log(
f'load {self.prefix} in model from: {self.checkpoint}',
logger=logger)
state_dict = _load_checkpoint_with_prefix(
self.prefix, self.checkpoint, map_location=self.map_location)
load_state_dict(module, state_dict, strict=False, logger=logger)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self) -> str:
info = f'{self.__class__.__name__}: load from {self.checkpoint}'
return info
def _initialize(module: nn.Module,
cfg: Dict,
wholemodule: bool = False) -> None:
func = build_from_cfg(cfg, INITIALIZERS)
# The wholemodule flag is for override mode: an override has no `layer` key,
# so the initializer applies to the whole sub-module whose name is given in
# the override.
func.wholemodule = wholemodule
func(module)
def _initialize_override(module: nn.Module, override: Union[Dict, List],
cfg: Dict) -> None:
if not isinstance(override, (dict, list)):
raise TypeError(f'override must be a dict or a list of dict, \
but got {type(override)}')
override = [override] if isinstance(override, dict) else override
for override_ in override:
cp_override = copy.deepcopy(override_)
name = cp_override.pop('name', None)
if name is None:
raise ValueError('`override` must contain the key "name", '
f'but got {cp_override}')
# if override only has name key, it means use args in init_cfg
if not cp_override:
cp_override.update(cfg)
# if override has name key and other args except type key, it will
# raise error
elif 'type' not in cp_override.keys():
raise ValueError(
f'`override` need "type" key, but got {cp_override}')
if hasattr(module, name):
_initialize(getattr(module, name), cp_override, wholemodule=True)
else:
raise RuntimeError(f'module did not have attribute {name}, '
f'but init_cfg is {cp_override}.')
def initialize(module: nn.Module, init_cfg: Union[Dict, List[dict]]) -> None:
r"""Initialize a module.
Args:
module (``torch.nn.Module``): the module to be initialized.
init_cfg (dict | list[dict]): initialization configuration dict to
define initializer. OpenMMLab has implemented 6 initializers
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
``Kaiming``, and ``Pretrained``.
Example:
>>> module = nn.Linear(2, 3, bias=True)
>>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2)
>>> initialize(module, init_cfg)
>>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
>>> # define key ``'layer'`` for initializing layer with different
>>> # configuration
>>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
dict(type='Constant', layer='Linear', val=2)]
>>> initialize(module, init_cfg)
>>> # define key``'override'`` to initialize some specific part in
>>> # module
>>> class FooNet(nn.Module):
>>> def __init__(self):
>>> super().__init__()
>>> self.feat = nn.Conv2d(3, 16, 3)
>>> self.reg = nn.Conv2d(16, 10, 3)
>>> self.cls = nn.Conv2d(16, 5, 3)
>>> model = FooNet()
>>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
>>> override=dict(type='Constant', name='reg', val=3, bias=4))
>>> initialize(model, init_cfg)
>>> model = ResNet(depth=50)
>>> # Initialize weights with the pretrained model.
>>> init_cfg = dict(type='Pretrained',
checkpoint='torchvision://resnet50')
>>> initialize(model, init_cfg)
>>> # Initialize weights of a sub-module with the specific part of
>>> # a pretrained model by using "prefix".
>>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
>>> 'retinanet_r50_fpn_1x_coco/'\
>>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
>>> init_cfg = dict(type='Pretrained',
checkpoint=url, prefix='backbone.')
"""
if not isinstance(init_cfg, (dict, list)):
raise TypeError(f'init_cfg must be a dict or a list of dict, \
but got {type(init_cfg)}')
if isinstance(init_cfg, dict):
init_cfg = [init_cfg]
for cfg in init_cfg:
# should deeply copy the original config because cfg may be used by
# other modules, e.g., one init_cfg shared by multiple bottleneck
# blocks, the expected cfg will be changed after pop and will change
# the initialization behavior of other modules
cp_cfg = copy.deepcopy(cfg)
override = cp_cfg.pop('override', None)
_initialize(module, cp_cfg)
if override is not None:
cp_cfg.pop('layer', None)
_initialize_override(module, override, cp_cfg)
else:
# All attributes in module have same initialization.
pass
def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
b: float) -> Tensor:
# Method based on
# https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
# Modified from
# https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
lower = norm_cdf((a - mean) / std)
upper = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [lower, upper], then translate
# to [2lower-1, 2upper-1].
tensor.uniform_(2 * lower - 1, 2 * upper - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor: Tensor,
mean: float = 0.,
std: float = 1.,
a: float = -2.,
b: float = 2.) -> Tensor:
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Modified from
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
Args:
tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
mean (float): the mean of the normal distribution.
std (float): the standard deviation of the normal distribution.
a (float): the minimum cutoff value.
b (float): the maximum cutoff value.
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
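The helpers in the file above are removed because mmengine ships its own copies; the import changes elsewhere in this diff show the replacement paths. A minimal sketch of switching to them (the module path mirrors this commit; later mmengine releases re-export the same helpers from mmengine.model):
import torch.nn as nn
from mmengine.model.utils import constant_init, kaiming_init, xavier_init
conv = nn.Conv2d(3, 8, 3)
kaiming_init(conv, mode='fan_out', nonlinearity='relu')
bn = nn.BatchNorm2d(8)
constant_init(bn, val=1.0, bias=0.0)
fc = nn.Linear(8, 4)
xavier_init(fc, distribution='uniform')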
......@@ -3,10 +3,9 @@ import logging
from typing import List, Optional, Sequence, Tuple, Union
import torch.nn as nn
from mmengine.model.utils import constant_init, kaiming_init, normal_init
from torch import Tensor
from .utils import constant_init, kaiming_init, normal_init
def conv3x3(in_planes: int, out_planes: int, dilation: int = 1) -> nn.Module:
"""3x3 convolution with padding."""
......
......@@ -4,11 +4,12 @@ from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model.utils import normal_init, xavier_init
from mmengine.registry import MODELS
from torch import Tensor
from torch.autograd import Function
from torch.nn.modules.module import Module
from ..cnn import UPSAMPLE_LAYERS, normal_init, xavier_init
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', [
......@@ -219,7 +220,7 @@ class CARAFE(Module):
self.scale_factor)
@UPSAMPLE_LAYERS.register_module(name='carafe')
@MODELS.register_module(name='carafe')
class CARAFEPack(nn.Module):
"""A unified package of CARAFE upsampler that contains: 1) channel
compressor 2) content encoder 3) CARAFE op.
......
......@@ -2,8 +2,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.registry import MODELS
from mmcv.cnn import PLUGIN_LAYERS, Scale
from mmcv.cnn import Scale
def NEG_INF_DIAG(n: int, device: torch.device) -> torch.Tensor:
......@@ -15,7 +16,7 @@ def NEG_INF_DIAG(n: int, device: torch.device) -> torch.Tensor:
return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)
@PLUGIN_LAYERS.register_module()
@MODELS.register_module()
class CrissCrossAttention(nn.Module):
"""Criss-Cross Attention Module.
......
......@@ -4,14 +4,15 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine import print_log
from mmengine.registry import MODELS
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single
from mmcv.utils import deprecated_api_warning
from ..cnn import CONV_LAYERS
from ..utils import ext_loader, print_log
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', [
'deform_conv_forward', 'deform_conv_backward_input',
......@@ -330,7 +331,7 @@ class DeformConv2d(nn.Module):
return s
@CONV_LAYERS.register_module('DCN')
@MODELS.register_module('DCN')
class DeformConv2dPack(DeformConv2d):
"""A Deformable Conv Encapsulation that acts as normal Conv layers.
......
......@@ -4,13 +4,14 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine import print_log
from mmengine.registry import MODELS
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single
from mmcv.utils import deprecated_api_warning
from ..cnn import CONV_LAYERS
from ..utils import ext_loader, print_log
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext',
......@@ -208,7 +209,7 @@ class ModulatedDeformConv2d(nn.Module):
self.deform_groups)
@CONV_LAYERS.register_module('DCNv2')
@MODELS.register_module('DCNv2')
class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
"""A ModulatedDeformable Conv Encapsulation that acts as normal Conv
layers.
......
......@@ -6,13 +6,13 @@ from typing import Optional, no_type_check
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.model.utils import constant_init, xavier_init
from mmengine.registry import MODELS
from torch.autograd.function import Function, once_differentiable
import mmcv
from mmcv import deprecated_api_warning
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner import BaseModule
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
......@@ -156,7 +156,7 @@ def multi_scale_deformable_attn_pytorch(
return output.transpose(1, 2).contiguous()
@ATTENTION.register_module()
@MODELS.register_module()
class MultiScaleDeformableAttention(BaseModule):
"""An attention module used in Deformable-Detr.
......
......@@ -2,13 +2,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model.utils import constant_init
from mmengine.registry import MODELS
from mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
from mmcv.cnn import ConvAWS2d
from mmcv.ops.deform_conv import deform_conv2d
from mmcv.utils import TORCH_VERSION, digit_version
@CONV_LAYERS.register_module(name='SAC')
@MODELS.register_module(name='SAC')
class SAConv2d(ConvAWS2d):
"""SAC (Switchable Atrous Convolution)
......
......@@ -15,10 +15,10 @@ import math
import numpy as np
import torch
from mmengine.registry import MODELS
from torch.nn import init
from torch.nn.parameter import Parameter
from ..cnn import CONV_LAYERS
from . import sparse_functional as Fsp
from . import sparse_ops as ops
from .sparse_modules import SparseModule
......@@ -204,7 +204,7 @@ class SparseConvolution(SparseModule):
return out_tensor
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseConv2d(SparseConvolution):
def __init__(self,
......@@ -230,7 +230,7 @@ class SparseConv2d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseConv3d(SparseConvolution):
def __init__(self,
......@@ -256,7 +256,7 @@ class SparseConv3d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseConv4d(SparseConvolution):
def __init__(self,
......@@ -282,7 +282,7 @@ class SparseConv4d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseConvTranspose2d(SparseConvolution):
def __init__(self,
......@@ -309,7 +309,7 @@ class SparseConvTranspose2d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseConvTranspose3d(SparseConvolution):
def __init__(self,
......@@ -336,7 +336,7 @@ class SparseConvTranspose3d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseInverseConv2d(SparseConvolution):
def __init__(self,
......@@ -355,7 +355,7 @@ class SparseInverseConv2d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseInverseConv3d(SparseConvolution):
def __init__(self,
......@@ -374,7 +374,7 @@ class SparseInverseConv3d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SubMConv2d(SparseConvolution):
def __init__(self,
......@@ -401,7 +401,7 @@ class SubMConv2d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SubMConv3d(SparseConvolution):
def __init__(self,
......@@ -428,7 +428,7 @@ class SubMConv3d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SubMConv4d(SparseConvolution):
def __init__(self,
......
......@@ -4,12 +4,12 @@ from typing import Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmengine.registry import MODELS
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
from mmcv.cnn import NORM_LAYERS
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', [
......@@ -159,7 +159,7 @@ class SyncBatchNormFunction(Function):
None, None, None, None, None
@NORM_LAYERS.register_module(name='MMSyncBN')
@MODELS.register_module(name='MMSyncBN')
class SyncBatchNorm(Module):
"""Synchronized Batch Normalization.
......
# Copyright (c) OpenMMLab. All rights reserved.
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .base_runner import BaseRunner
from .builder import RUNNERS, build_runner
from .checkpoint import (CheckpointLoader, _load_checkpoint,
......@@ -64,10 +63,10 @@ __all__ = [
'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
'_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
'ModuleDict', 'ModuleList', 'GradientCumulativeOptimizerHook',
'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor',
'SegmindLoggerHook', 'LinearAnnealingMomentumUpdaterHook',
'LinearAnnealingLrUpdaterHook', 'ClearMLLoggerHook'
'allreduce_params', 'LossScaler', 'CheckpointLoader',
'_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook',
'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook',
'DefaultRunnerConstructor', 'SegmindLoggerHook',
'LinearAnnealingMomentumUpdaterHook', 'LinearAnnealingLrUpdaterHook',
'ClearMLLoggerHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler
from typing import Iterable, Optional
import torch.nn as nn
from mmcv.runner.dist_utils import master_only
from mmcv.utils.logging import get_logger, logger_initialized, print_log
class BaseModule(nn.Module, metaclass=ABCMeta):
"""Base module for all modules in openmmlab.
``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
functionality of parameter initialization. Compared with
``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
- ``init_cfg``: the config to control the initialization.
- ``init_weights``: The function of parameter initialization and recording
initialization information.
- ``_params_init_info``: Used to track the parameter initialization
information. This attribute only exists during executing the
``init_weights``.
Args:
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, init_cfg: Optional[dict] = None):
"""Initialize BaseModule, inherited from `torch.nn.Module`"""
# NOTE init_cfg can be defined in different levels, but init_cfg
# in low levels has a higher priority.
super().__init__()
# define default value of init_cfg instead of hard code
# in init_weights() function
self._is_init = False
self.init_cfg = copy.deepcopy(init_cfg)
# Backward compatibility in derived classes
# if pretrained is not None:
# warnings.warn('DeprecationWarning: pretrained is a deprecated \
# key, please consider using init_cfg')
# self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
@property
def is_init(self) -> bool:
return self._is_init
def init_weights(self) -> None:
"""Initialize the weights."""
is_top_level_module = False
# check if it is top-level module
if not hasattr(self, '_params_init_info'):
# The `_params_init_info` is used to record the initialization
# information of the parameters
# the key should be the obj:`nn.Parameter` of model and the value
# should be a dict containing
# - init_info (str): The string that describes the initialization.
# - tmp_mean_value (FloatTensor): The mean of the parameter,
# which indicates whether the parameter has been modified.
# this attribute will be deleted after all parameters are initialized.
self._params_init_info: defaultdict = defaultdict(dict)
is_top_level_module = True
# Initialize the `_params_init_info`,
# When detecting the `tmp_mean_value` of
# the corresponding parameter is changed, update related
# initialization information
for name, param in self.named_parameters():
self._params_init_info[param][
'init_info'] = f'The value is the same before and ' \
f'after calling `init_weights` ' \
f'of {self.__class__.__name__} '
self._params_init_info[param][
'tmp_mean_value'] = param.data.mean()
# pass `params_init_info` to all submodules
# All submodules share the same `params_init_info`,
# so it will be updated when parameters are
# modified at any level of the model.
for sub_module in self.modules():
sub_module._params_init_info = self._params_init_info
# Get the initialized logger, if not exist,
# create a logger named `mmcv`
logger_names = list(logger_initialized.keys())
logger_name = logger_names[0] if logger_names else 'mmcv'
from ..cnn import initialize
from ..cnn.utils.weight_init import update_init_info
module_name = self.__class__.__name__
if not self._is_init:
if self.init_cfg:
print_log(
f'initialize {module_name} with init_cfg {self.init_cfg}',
logger=logger_name)
initialize(self, self.init_cfg)
if isinstance(self.init_cfg, dict):
# prevent the parameters of
# the pre-trained model
# from being overwritten by
# the `init_weights`
if self.init_cfg['type'] == 'Pretrained':
return
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights()
# users may overload the `init_weights`
update_init_info(
m,
init_info=f'Initialized by '
f'user-defined `init_weights`'
f' in {m.__class__.__name__} ')
self._is_init = True
else:
warnings.warn(f'init_weights of {self.__class__.__name__} has '
f'been called more than once.')
if is_top_level_module:
self._dump_init_info(logger_name)
for sub_module in self.modules():
del sub_module._params_init_info
@master_only
def _dump_init_info(self, logger_name: str) -> None:
"""Dump the initialization information to a file named
`initialization.log.json` in workdir.
Args:
logger_name (str): The name of logger.
"""
logger = get_logger(logger_name)
with_file_handler = False
# dump the information to the logger file if there is a `FileHandler`
for handler in logger.handlers:
if isinstance(handler, FileHandler):
handler.stream.write(
'Name of parameter - Initialization information\n')
for name, param in self.named_parameters():
handler.stream.write(
f'\n{name} - {param.shape}: '
f"\n{self._params_init_info[param]['init_info']} \n")
handler.stream.flush()
with_file_handler = True
if not with_file_handler:
for name, param in self.named_parameters():
print_log(
f'\n{name} - {param.shape}: '
f"\n{self._params_init_info[param]['init_info']} \n ",
logger=logger_name)
def __repr__(self):
s = super().__repr__()
if self.init_cfg:
s += f'\ninit_cfg={self.init_cfg}'
return s
class Sequential(BaseModule, nn.Sequential):
"""Sequential module in openmmlab.
Args:
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, *args, init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.Sequential.__init__(self, *args)
class ModuleList(BaseModule, nn.ModuleList):
"""ModuleList in openmmlab.
Args:
modules (iterable, optional): an iterable of modules to add.
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self,
modules: Optional[Iterable] = None,
init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.ModuleList.__init__(self, modules)
class ModuleDict(BaseModule, nn.ModuleDict):
"""ModuleDict in openmmlab.
Args:
modules (dict, optional): a mapping (dictionary) of (string: module)
or an iterable of key-value pairs of type (string, module).
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self,
modules: Optional[dict] = None,
init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.ModuleDict.__init__(self, modules)
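With base_module.py deleted, downstream code imports BaseModule, Sequential, ModuleList and ModuleDict from mmengine.model instead, as the transformer hunk earlier in this diff already does. A short sketch of the replacement usage (TinyHead and its init_cfg are illustrative only):
import torch.nn as nn
from mmengine.model import BaseModule
class TinyHead(BaseModule):
    def __init__(self):
        # init_cfg drives init_weights(), just as in the deleted mmcv version.
        super().__init__(init_cfg=dict(type='Normal', layer='Conv2d', std=0.01))
        self.conv = nn.Conv2d(16, 4, 3, padding=1)
    def forward(self, x):
        return self.conv(x)
head = TinyHead()
head.init_weights()  # applies the Normal initializer to the Conv2d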