Unverified commit 19a02415 authored by Mashiro, committed by GitHub

[Refactor] Use MODELS registry in mmengine and delete basemodule (#2172)

* change MODELS to mmengine, delete basemodule

* fix unit test

* remove build from cfg

* add comment and rename TARGET_MODELS to registry

* refine cnn docs

* remove unnecessary check

* refine as comment

* refine build_xxx_conv error message

* fix lint

* fix import registry from mmcv

* remove unused file
parent f6fd6c21
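The gist of the change, repeated across every file below: modules previously registered in per-type registries (CONV_LAYERS, ATTENTION, UPSAMPLE_LAYERS, ...) are now registered in, and built from, mmengine's shared MODELS registry. A minimal sketch of the new pattern; `MyConv` is a hypothetical example module, not part of mmcv:

import torch.nn as nn
from mmengine.registry import MODELS

@MODELS.register_module()  # replaces e.g. @CONV_LAYERS.register_module()
class MyConv(nn.Conv2d):
    """Hypothetical example layer."""

# Configs are now resolved through the shared registry.
conv = MODELS.build(
    dict(type='MyConv', in_channels=3, out_channels=8, kernel_size=3))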
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry
CONV_LAYERS = Registry('conv layer')
NORM_LAYERS = Registry('norm layer')
ACTIVATION_LAYERS = Registry('activation layer')
PADDING_LAYERS = Registry('padding layer')
UPSAMPLE_LAYERS = Registry('upsample layer')
PLUGIN_LAYERS = Registry('plugin layer')
DROPOUT_LAYERS = Registry('drop out layers')
POSITIONAL_ENCODING = Registry('position encoding')
ATTENTION = Registry('attention')
FEEDFORWARD_NETWORK = Registry('feed-forward Network')
TRANSFORMER_LAYER = Registry('transformerLayer')
TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.registry import MODELS
from .registry import ACTIVATION_LAYERS
@ACTIVATION_LAYERS.register_module()
@MODELS.register_module()
class Swish(nn.Module):
"""Swish Module.
......
......@@ -7,15 +7,14 @@ from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine import ConfigDict
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.registry import MODELS
from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
build_norm_layer)
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
to_2tuple)
from mmcv.utils import deprecated_api_warning, to_2tuple
from .drop import build_dropout
from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
......@@ -37,27 +36,27 @@ except ImportError:
def build_positional_encoding(cfg, default_args=None):
"""Builder for Position Encoding."""
return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
return MODELS.build(cfg, default_args=default_args)
def build_attention(cfg, default_args=None):
"""Builder for attention."""
return build_from_cfg(cfg, ATTENTION, default_args)
return MODELS.build(cfg, default_args=default_args)
def build_feedforward_network(cfg, default_args=None):
"""Builder for feed-forward network (FFN)."""
return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
return MODELS.build(cfg, default_args=default_args)
def build_transformer_layer(cfg, default_args=None):
"""Builder for transformer layer."""
return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
return MODELS.build(cfg, default_args=default_args)
def build_transformer_layer_sequence(cfg, default_args=None):
"""Builder for transformer encoder and transformer decoder."""
return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
return MODELS.build(cfg, default_args=default_args)
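These thin builders are kept for backward compatibility, but all of them now delegate to MODELS.build, so the two calls below should be interchangeable. A hedged usage sketch; the embed_dims/feedforward_channels values are arbitrary:

from mmengine.registry import MODELS
from mmcv.cnn.bricks.transformer import FFN, build_feedforward_network

ffn_cfg = dict(type='FFN', embed_dims=256, feedforward_channels=1024)
ffn_a = build_feedforward_network(ffn_cfg)   # legacy builder
ffn_b = MODELS.build(ffn_cfg)                # equivalent direct call
assert isinstance(ffn_a, FFN) and isinstance(ffn_b, FFN)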
class AdaptivePadding(nn.Module):
......@@ -403,7 +402,7 @@ class PatchMerging(BaseModule):
return x, output_size
@ATTENTION.register_module()
@MODELS.register_module()
class MultiheadAttention(BaseModule):
"""A wrapper for ``torch.nn.MultiheadAttention``.
......@@ -551,7 +550,7 @@ class MultiheadAttention(BaseModule):
return identity + self.dropout_layer(self.proj_drop(out))
@FEEDFORWARD_NETWORK.register_module()
@MODELS.register_module()
class FFN(BaseModule):
"""Implements feed-forward networks (FFNs) with identity connection.
......@@ -628,7 +627,7 @@ class FFN(BaseModule):
return identity + self.dropout_layer(out)
@TRANSFORMER_LAYER.register_module()
@MODELS.register_module()
class BaseTransformerLayer(BaseModule):
"""Base `TransformerLayer` for vision transformer.
......@@ -859,7 +858,7 @@ class BaseTransformerLayer(BaseModule):
return query
@TRANSFORMER_LAYER_SEQUENCE.register_module()
@MODELS.register_module()
class TransformerLayerSequence(BaseModule):
"""Base class for TransformerEncoder and TransformerDecoder in vision
transformer.
......
......@@ -4,15 +4,14 @@ from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model.utils import xavier_init
from mmengine.registry import MODELS
from ..utils import xavier_init
from .registry import UPSAMPLE_LAYERS
MODELS.register_module('nearest', module=nn.Upsample)
MODELS.register_module('bilinear', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
@MODELS.register_module(name='pixel_shuffle')
class PixelShufflePack(nn.Module):
"""Pixel Shuffle upsample layer.
......@@ -76,11 +75,15 @@ def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
cfg_ = cfg.copy()
layer_type = cfg_.pop('type')
if layer_type not in UPSAMPLE_LAYERS:
raise KeyError(f'Unrecognized upsample type {layer_type}')
else:
upsample = UPSAMPLE_LAYERS.get(layer_type)
# Switch registry to the target scope. If `upsample` cannot be found
# in the registry, fallback to search `upsample` in the
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
upsample = registry.get(layer_type)
if upsample is None:
raise KeyError(f'Cannot find {layer_type} in registry under scope '
f'name {registry.scope}')
if upsample is nn.Upsample:
cfg_['mode'] = layer_type
layer = upsample(*args, **kwargs, **cfg_)
......
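The scope-switching lookup above resolves the type in the registry of the current default scope and falls back to mmengine's MODELS. A hedged usage sketch of the builder; channel numbers are arbitrary:

from mmcv.cnn import build_upsample_layer

# 'pixel_shuffle' resolves to PixelShufflePack through MODELS.
up1 = build_upsample_layer(
    dict(type='pixel_shuffle', in_channels=16, out_channels=16,
         scale_factor=2, upsample_kernel=3))

# The nn.Upsample shortcuts still work; the type name becomes `mode`.
up2 = build_upsample_layer(dict(type='nearest', scale_factor=2))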
......@@ -9,10 +9,9 @@ import math
import torch
import torch.nn as nn
from mmengine.registry import MODELS
from torch.nn.modules.utils import _pair, _triple
from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
if torch.__version__ == 'parrots':
TORCH_VERSION = torch.__version__
else:
......@@ -38,7 +37,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
return NewEmptyTensorOp.apply(grad, shape), None
@CONV_LAYERS.register_module('Conv', force=True)
@MODELS.register_module('Conv', force=True)
class Conv2d(nn.Conv2d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -59,7 +58,7 @@ class Conv2d(nn.Conv2d):
return super().forward(x)
@CONV_LAYERS.register_module('Conv3d', force=True)
@MODELS.register_module('Conv3d', force=True)
class Conv3d(nn.Conv3d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -80,9 +79,8 @@ class Conv3d(nn.Conv3d):
return super().forward(x)
@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv')
@UPSAMPLE_LAYERS.register_module('deconv', force=True)
@MODELS.register_module()
@MODELS.register_module('deconv')
class ConvTranspose2d(nn.ConvTranspose2d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -103,9 +101,8 @@ class ConvTranspose2d(nn.ConvTranspose2d):
return super().forward(x)
@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv3d')
@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
@MODELS.register_module()
@MODELS.register_module('deconv3d')
class ConvTranspose3d(nn.ConvTranspose3d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
......
# Copyright (c) OpenMMLab. All rights reserved.
from ..runner import Sequential
from ..utils import Registry, build_from_cfg
def build_model_from_cfg(cfg, registry, default_args=None):
"""Build a PyTorch model from config dict(s). Different from
``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
Args:
cfg (dict, list[dict]): The config of modules; it is either a config
dict or a list of config dicts. If cfg is a list, the built
modules will be wrapped with ``nn.Sequential``.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.
Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
MODELS = Registry('model', build_func=build_model_from_cfg)
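This local builder and its MODELS registry are removed; mmengine's MODELS is assumed to keep the same list-to-Sequential convention, so configs like the following should still build. The 'Conv' type (registered by mmcv's conv wrappers) and the channel numbers are illustrative:

import mmcv.cnn  # noqa: F401  (triggers mmcv's MODELS registrations)
from mmengine.registry import MODELS

convs = MODELS.build([
    dict(type='Conv', in_channels=3, out_channels=8, kernel_size=3),
    dict(type='Conv', in_channels=8, out_channels=8, kernel_size=3),
])
# `convs` is a Sequential wrapping the two conv layers.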
......@@ -4,10 +4,9 @@ from typing import Optional, Sequence, Tuple, Union
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmengine.model.utils import constant_init, kaiming_init
from torch import Tensor
from .utils import constant_init, kaiming_init
def conv3x3(in_planes: int,
out_planes: int,
......
......@@ -2,18 +2,7 @@
from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn
from .sync_bn import revert_sync_batchnorm
from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
KaimingInit, NormalInit, PretrainedInit,
TruncNormalInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init,
constant_init, initialize, kaiming_init, normal_init,
trunc_normal_init, uniform_init, xavier_init)
__all__ = [
'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'Caffe2XavierInit', 'revert_sync_batchnorm'
'get_model_complexity_info', 'fuse_conv_bn', 'revert_sync_batchnorm'
]
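The weight-initialization helpers are no longer exported from mmcv.cnn.utils; as the diffs elsewhere in this commit show, they are imported from mmengine instead. A short migration sketch:

# Old (removed from mmcv.cnn / mmcv.cnn.utils):
#   from mmcv.cnn import constant_init, kaiming_init, xavier_init
# New, matching the imports used throughout this PR:
from mmengine.model.utils import constant_init, kaiming_init, xavier_init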
......@@ -3,10 +3,9 @@ import logging
from typing import List, Optional, Sequence, Tuple, Union
import torch.nn as nn
from mmengine.model.utils import constant_init, kaiming_init, normal_init
from torch import Tensor
from .utils import constant_init, kaiming_init, normal_init
def conv3x3(in_planes: int, out_planes: int, dilation: int = 1) -> nn.Module:
"""3x3 convolution with padding."""
......
......@@ -4,11 +4,12 @@ from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model.utils import normal_init, xavier_init
from mmengine.registry import MODELS
from torch import Tensor
from torch.autograd import Function
from torch.nn.modules.module import Module
from ..cnn import UPSAMPLE_LAYERS, normal_init, xavier_init
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', [
......@@ -219,7 +220,7 @@ class CARAFE(Module):
self.scale_factor)
@UPSAMPLE_LAYERS.register_module(name='carafe')
@MODELS.register_module(name='carafe')
class CARAFEPack(nn.Module):
"""A unified package of CARAFE upsampler that contains: 1) channel
compressor 2) content encoder 3) CARAFE op.
......
......@@ -2,8 +2,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.registry import MODELS
from mmcv.cnn import PLUGIN_LAYERS, Scale
from mmcv.cnn import Scale
def NEG_INF_DIAG(n: int, device: torch.device) -> torch.Tensor:
......@@ -15,7 +16,7 @@ def NEG_INF_DIAG(n: int, device: torch.device) -> torch.Tensor:
return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)
@PLUGIN_LAYERS.register_module()
@MODELS.register_module()
class CrissCrossAttention(nn.Module):
"""Criss-Cross Attention Module.
......
......@@ -4,14 +4,15 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine import print_log
from mmengine.registry import MODELS
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single
from mmcv.utils import deprecated_api_warning
from ..cnn import CONV_LAYERS
from ..utils import ext_loader, print_log
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', [
'deform_conv_forward', 'deform_conv_backward_input',
......@@ -330,7 +331,7 @@ class DeformConv2d(nn.Module):
return s
@CONV_LAYERS.register_module('DCN')
@MODELS.register_module('DCN')
class DeformConv2dPack(DeformConv2d):
"""A Deformable Conv Encapsulation that acts as normal Conv layers.
......
......@@ -4,13 +4,14 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine import print_log
from mmengine.registry import MODELS
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single
from mmcv.utils import deprecated_api_warning
from ..cnn import CONV_LAYERS
from ..utils import ext_loader, print_log
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext',
......@@ -208,7 +209,7 @@ class ModulatedDeformConv2d(nn.Module):
self.deform_groups)
@CONV_LAYERS.register_module('DCNv2')
@MODELS.register_module('DCNv2')
class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
"""A ModulatedDeformable Conv Encapsulation that acts as normal Conv
layers.
......
......@@ -6,13 +6,13 @@ from typing import Optional, no_type_check
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.model.utils import constant_init, xavier_init
from mmengine.registry import MODELS
from torch.autograd.function import Function, once_differentiable
import mmcv
from mmcv import deprecated_api_warning
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner import BaseModule
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
......@@ -156,7 +156,7 @@ def multi_scale_deformable_attn_pytorch(
return output.transpose(1, 2).contiguous()
@ATTENTION.register_module()
@MODELS.register_module()
class MultiScaleDeformableAttention(BaseModule):
"""An attention module used in Deformable-Detr.
......
......@@ -2,13 +2,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model.utils import constant_init
from mmengine.registry import MODELS
from mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
from mmcv.cnn import ConvAWS2d
from mmcv.ops.deform_conv import deform_conv2d
from mmcv.utils import TORCH_VERSION, digit_version
@CONV_LAYERS.register_module(name='SAC')
@MODELS.register_module(name='SAC')
class SAConv2d(ConvAWS2d):
"""SAC (Switchable Atrous Convolution)
......
......@@ -15,10 +15,10 @@ import math
import numpy as np
import torch
from mmengine.registry import MODELS
from torch.nn import init
from torch.nn.parameter import Parameter
from ..cnn import CONV_LAYERS
from . import sparse_functional as Fsp
from . import sparse_ops as ops
from .sparse_modules import SparseModule
......@@ -204,7 +204,7 @@ class SparseConvolution(SparseModule):
return out_tensor
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseConv2d(SparseConvolution):
def __init__(self,
......@@ -230,7 +230,7 @@ class SparseConv2d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseConv3d(SparseConvolution):
def __init__(self,
......@@ -256,7 +256,7 @@ class SparseConv3d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseConv4d(SparseConvolution):
def __init__(self,
......@@ -282,7 +282,7 @@ class SparseConv4d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseConvTranspose2d(SparseConvolution):
def __init__(self,
......@@ -309,7 +309,7 @@ class SparseConvTranspose2d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseConvTranspose3d(SparseConvolution):
def __init__(self,
......@@ -336,7 +336,7 @@ class SparseConvTranspose3d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseInverseConv2d(SparseConvolution):
def __init__(self,
......@@ -355,7 +355,7 @@ class SparseInverseConv2d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SparseInverseConv3d(SparseConvolution):
def __init__(self,
......@@ -374,7 +374,7 @@ class SparseInverseConv3d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SubMConv2d(SparseConvolution):
def __init__(self,
......@@ -401,7 +401,7 @@ class SubMConv2d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SubMConv3d(SparseConvolution):
def __init__(self,
......@@ -428,7 +428,7 @@ class SubMConv3d(SparseConvolution):
indice_key=indice_key)
@CONV_LAYERS.register_module()
@MODELS.register_module()
class SubMConv4d(SparseConvolution):
def __init__(self,
......
......@@ -4,12 +4,12 @@ from typing import Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmengine.registry import MODELS
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
from mmcv.cnn import NORM_LAYERS
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', [
......@@ -159,7 +159,7 @@ class SyncBatchNormFunction(Function):
None, None, None, None, None
@NORM_LAYERS.register_module(name='MMSyncBN')
@MODELS.register_module(name='MMSyncBN')
class SyncBatchNorm(Module):
"""Synchronized Batch Normalization.
......
# Copyright (c) OpenMMLab. All rights reserved.
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .base_runner import BaseRunner
from .builder import RUNNERS, build_runner
from .checkpoint import (CheckpointLoader, _load_checkpoint,
......@@ -64,10 +63,10 @@ __all__ = [
'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
'_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
'ModuleDict', 'ModuleList', 'GradientCumulativeOptimizerHook',
'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor',
'SegmindLoggerHook', 'LinearAnnealingMomentumUpdaterHook',
'LinearAnnealingLrUpdaterHook', 'ClearMLLoggerHook'
'allreduce_params', 'LossScaler', 'CheckpointLoader',
'_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook',
'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook',
'DefaultRunnerConstructor', 'SegmindLoggerHook',
'LinearAnnealingMomentumUpdaterHook', 'LinearAnnealingLrUpdaterHook',
'ClearMLLoggerHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler
from typing import Iterable, Optional
import torch.nn as nn
from mmcv.runner.dist_utils import master_only
from mmcv.utils.logging import get_logger, logger_initialized, print_log
class BaseModule(nn.Module, metaclass=ABCMeta):
"""Base module for all modules in openmmlab.
``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
functionality of parameter initialization. Compared with
``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
- ``init_cfg``: the config to control the initialization.
- ``init_weights``: The function of parameter initialization and recording
initialization information.
- ``_params_init_info``: Used to track the parameter initialization
information. This attribute only exists during executing the
``init_weights``.
Args:
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, init_cfg: Optional[dict] = None):
"""Initialize BaseModule, inherited from `torch.nn.Module`"""
# NOTE init_cfg can be defined in different levels, but init_cfg
# in low levels has a higher priority.
super().__init__()
# define default value of init_cfg instead of hard code
# in init_weights() function
self._is_init = False
self.init_cfg = copy.deepcopy(init_cfg)
# Backward compatibility in derived classes
# if pretrained is not None:
# warnings.warn('DeprecationWarning: pretrained is a deprecated \
# key, please consider using init_cfg')
# self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
@property
def is_init(self) -> bool:
return self._is_init
def init_weights(self) -> None:
"""Initialize the weights."""
is_top_level_module = False
# check if it is top-level module
if not hasattr(self, '_params_init_info'):
# The `_params_init_info` is used to record the initialization
# information of the parameters
# the key should be the :obj:`nn.Parameter` of the model and the value
# should be a dict containing
# - init_info (str): The string that describes the initialization.
# - tmp_mean_value (FloatTensor): The mean of the parameter,
# which indicates whether the parameter has been modified.
# this attribute will be deleted after all parameters
# are initialized.
self._params_init_info: defaultdict = defaultdict(dict)
is_top_level_module = True
# Initialize `_params_init_info`. When the `tmp_mean_value` of a
# parameter is detected to have changed, the related
# initialization information is updated.
for name, param in self.named_parameters():
self._params_init_info[param][
'init_info'] = f'The value is the same before and ' \
f'after calling `init_weights` ' \
f'of {self.__class__.__name__} '
self._params_init_info[param][
'tmp_mean_value'] = param.data.mean()
# pass `params_init_info` to all submodules
# All submodules share the same `params_init_info`,
# so it will be updated when parameters are
# modified at any level of the model.
for sub_module in self.modules():
sub_module._params_init_info = self._params_init_info
# Get the initialized logger, if not exist,
# create a logger named `mmcv`
logger_names = list(logger_initialized.keys())
logger_name = logger_names[0] if logger_names else 'mmcv'
from ..cnn import initialize
from ..cnn.utils.weight_init import update_init_info
module_name = self.__class__.__name__
if not self._is_init:
if self.init_cfg:
print_log(
f'initialize {module_name} with init_cfg {self.init_cfg}',
logger=logger_name)
initialize(self, self.init_cfg)
if isinstance(self.init_cfg, dict):
# prevent the parameters of
# the pre-trained model
# from being overwritten by
# the `init_weights`
if self.init_cfg['type'] == 'Pretrained':
return
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights()
# users may overload the `init_weights`
update_init_info(
m,
init_info=f'Initialized by '
f'user-defined `init_weights`'
f' in {m.__class__.__name__} ')
self._is_init = True
else:
warnings.warn(f'init_weights of {self.__class__.__name__} has '
f'been called more than once.')
if is_top_level_module:
self._dump_init_info(logger_name)
for sub_module in self.modules():
del sub_module._params_init_info
@master_only
def _dump_init_info(self, logger_name: str) -> None:
"""Dump the initialization information to a file named
`initialization.log.json` in workdir.
Args:
logger_name (str): The name of logger.
"""
logger = get_logger(logger_name)
with_file_handler = False
# dump the information to the logger file if there is a `FileHandler`
for handler in logger.handlers:
if isinstance(handler, FileHandler):
handler.stream.write(
'Name of parameter - Initialization information\n')
for name, param in self.named_parameters():
handler.stream.write(
f'\n{name} - {param.shape}: '
f"\n{self._params_init_info[param]['init_info']} \n")
handler.stream.flush()
with_file_handler = True
if not with_file_handler:
for name, param in self.named_parameters():
print_log(
f'\n{name} - {param.shape}: '
f"\n{self._params_init_info[param]['init_info']} \n ",
logger=logger_name)
def __repr__(self):
s = super().__repr__()
if self.init_cfg:
s += f'\ninit_cfg={self.init_cfg}'
return s
class Sequential(BaseModule, nn.Sequential):
"""Sequential module in openmmlab.
Args:
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, *args, init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.Sequential.__init__(self, *args)
class ModuleList(BaseModule, nn.ModuleList):
"""ModuleList in openmmlab.
Args:
modules (iterable, optional): an iterable of modules to add.
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self,
modules: Optional[Iterable] = None,
init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.ModuleList.__init__(self, modules)
class ModuleDict(BaseModule, nn.ModuleDict):
"""ModuleDict in openmmlab.
Args:
modules (dict, optional): a mapping (dictionary) of (string: module)
or an iterable of key-value pairs of type (string, module).
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self,
modules: Optional[dict] = None,
init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.ModuleDict.__init__(self, modules)
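With this file removed, BaseModule, Sequential, ModuleList and ModuleDict are provided by mmengine instead. A minimal usage sketch of the replacement import; the ToyHead class and its Xavier init_cfg are illustrative:

import torch.nn as nn
from mmengine.model import BaseModule

class ToyHead(BaseModule):
    """Hypothetical module relying on init_cfg-driven initialization."""

    def __init__(self):
        super().__init__(init_cfg=dict(type='Xavier', layer='Linear'))
        self.fc = nn.Linear(8, 2)

head = ToyHead()
head.init_weights()  # applies the Xavier init described by init_cfg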