Unverified Commit 19a02415 authored by Mashiro, committed by GitHub

[Refactor] Use MODELS registry in mmengine and delete basemodule (#2172)

* change MODELS to mmengine, delete basemodule

* fix unit test

* remove build from cfg

* add comment and rename TARGET_MODELS to registry

* refine cnn docs

* remove unnecessary check

* refine as comment

* refine build_xxx_conv error message

* fix lint

* fix import registry from mmcv

* remove unused file
parent f6fd6c21
mmcv/cnn/bricks/registry.py (deleted)
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry

CONV_LAYERS = Registry('conv layer')
NORM_LAYERS = Registry('norm layer')
ACTIVATION_LAYERS = Registry('activation layer')
PADDING_LAYERS = Registry('padding layer')
UPSAMPLE_LAYERS = Registry('upsample layer')
PLUGIN_LAYERS = Registry('plugin layer')

DROPOUT_LAYERS = Registry('drop out layers')
POSITIONAL_ENCODING = Registry('position encoding')
ATTENTION = Registry('attention')
FEEDFORWARD_NETWORK = Registry('feed-forward Network')
TRANSFORMER_LAYER = Registry('transformerLayer')
TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
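
The file above is deleted because its per-category registries are superseded by the single MODELS registry in mmengine. A minimal before/after sketch of the registration pattern (MyActivation is a hypothetical module, not from this diff):

import torch.nn as nn
from mmengine.registry import MODELS

# Before: @ACTIVATION_LAYERS.register_module() with
# `from mmcv.cnn.bricks.registry import ACTIVATION_LAYERS`.
# After: every brick registers into the shared mmengine registry.
@MODELS.register_module()
class MyActivation(nn.Module):  # hypothetical example module

    def forward(self, x):
        return x * x.sigmoid()

act = MODELS.build(dict(type='MyActivation'))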
mmcv/cnn/bricks/swish.py
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 import torch.nn as nn
-from .registry import ACTIVATION_LAYERS
+from mmengine.registry import MODELS


-@ACTIVATION_LAYERS.register_module()
+@MODELS.register_module()
 class Swish(nn.Module):
     """Swish Module.
mmcv/cnn/bricks/transformer.py
@@ -7,15 +7,14 @@ from typing import Sequence
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from mmengine import ConfigDict
+from mmengine.model import BaseModule, ModuleList, Sequential
+from mmengine.registry import MODELS

 from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
                       build_norm_layer)
-from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
-from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
-                        to_2tuple)
+from mmcv.utils import deprecated_api_warning, to_2tuple
 from .drop import build_dropout
-from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
-                       TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
 # Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
 try:
@@ -37,27 +36,27 @@ except ImportError:
 def build_positional_encoding(cfg, default_args=None):
     """Builder for Position Encoding."""
-    return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
+    return MODELS.build(cfg, default_args=default_args)


 def build_attention(cfg, default_args=None):
     """Builder for attention."""
-    return build_from_cfg(cfg, ATTENTION, default_args)
+    return MODELS.build(cfg, default_args=default_args)


 def build_feedforward_network(cfg, default_args=None):
     """Builder for feed-forward network (FFN)."""
-    return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
+    return MODELS.build(cfg, default_args=default_args)


 def build_transformer_layer(cfg, default_args=None):
     """Builder for transformer layer."""
-    return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
+    return MODELS.build(cfg, default_args=default_args)


 def build_transformer_layer_sequence(cfg, default_args=None):
     """Builder for transformer encoder and transformer decoder."""
-    return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
+    return MODELS.build(cfg, default_args=default_args)
 class AdaptivePadding(nn.Module):
@@ -403,7 +402,7 @@ class PatchMerging(BaseModule):
         return x, output_size


-@ATTENTION.register_module()
+@MODELS.register_module()
 class MultiheadAttention(BaseModule):
     """A wrapper for ``torch.nn.MultiheadAttention``.
@@ -551,7 +550,7 @@ class MultiheadAttention(BaseModule):
         return identity + self.dropout_layer(self.proj_drop(out))


-@FEEDFORWARD_NETWORK.register_module()
+@MODELS.register_module()
 class FFN(BaseModule):
     """Implements feed-forward networks (FFNs) with identity connection.
@@ -628,7 +627,7 @@ class FFN(BaseModule):
         return identity + self.dropout_layer(out)


-@TRANSFORMER_LAYER.register_module()
+@MODELS.register_module()
 class BaseTransformerLayer(BaseModule):
     """Base `TransformerLayer` for vision transformer.
@@ -859,7 +858,7 @@ class BaseTransformerLayer(BaseModule):
         return query


-@TRANSFORMER_LAYER_SEQUENCE.register_module()
+@MODELS.register_module()
 class TransformerLayerSequence(BaseModule):
     """Base class for TransformerEncoder and TransformerDecoder in vision
     transformer.
mmcv/cnn/bricks/upsample.py
@@ -4,15 +4,14 @@ from typing import Dict
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from mmengine.model.utils import xavier_init
+from mmengine.registry import MODELS

-from ..utils import xavier_init
-from .registry import UPSAMPLE_LAYERS

-UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
-UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
+MODELS.register_module('nearest', module=nn.Upsample)
+MODELS.register_module('bilinear', module=nn.Upsample)


-@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
+@MODELS.register_module(name='pixel_shuffle')
 class PixelShufflePack(nn.Module):
     """Pixel Shuffle upsample layer.
@@ -76,11 +75,15 @@ def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
     cfg_ = cfg.copy()

     layer_type = cfg_.pop('type')
-    if layer_type not in UPSAMPLE_LAYERS:
-        raise KeyError(f'Unrecognized upsample type {layer_type}')
-    else:
-        upsample = UPSAMPLE_LAYERS.get(layer_type)
+    # Switch the registry to the target scope. If `upsample` cannot be
+    # found in the current registry, fall back to searching for it in
+    # mmengine's MODELS.
+    with MODELS.switch_scope_and_registry(None) as registry:
+        upsample = registry.get(layer_type)
+    if upsample is None:
+        raise KeyError(f'Cannot find {layer_type} in registry under scope '
+                       f'name {registry.scope}')

     if upsample is nn.Upsample:
         cfg_['mode'] = layer_type
     layer = upsample(*args, **kwargs, **cfg_)
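
With this change, 'nearest' and 'bilinear' still map to nn.Upsample (the type name doubles as the mode), and other upsamplers are looked up through the scope-aware MODELS registry. A usage sketch with illustrative channel sizes:

from mmcv.cnn import build_upsample_layer

up = build_upsample_layer(dict(type='nearest', scale_factor=2))
ps = build_upsample_layer(
    dict(type='pixel_shuffle', in_channels=64, out_channels=64,
         scale_factor=2, upsample_kernel=3))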
mmcv/cnn/bricks/wrappers.py
@@ -9,10 +9,9 @@ import math
 import torch
 import torch.nn as nn
+from mmengine.registry import MODELS
 from torch.nn.modules.utils import _pair, _triple

-from .registry import CONV_LAYERS, UPSAMPLE_LAYERS

 if torch.__version__ == 'parrots':
     TORCH_VERSION = torch.__version__
 else:
@@ -38,7 +37,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
         return NewEmptyTensorOp.apply(grad, shape), None


-@CONV_LAYERS.register_module('Conv', force=True)
+@MODELS.register_module('Conv', force=True)
 class Conv2d(nn.Conv2d):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -59,7 +58,7 @@ class Conv2d(nn.Conv2d):
         return super().forward(x)


-@CONV_LAYERS.register_module('Conv3d', force=True)
+@MODELS.register_module('Conv3d', force=True)
 class Conv3d(nn.Conv3d):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -80,9 +79,8 @@ class Conv3d(nn.Conv3d):
         return super().forward(x)


-@CONV_LAYERS.register_module()
-@CONV_LAYERS.register_module('deconv')
-@UPSAMPLE_LAYERS.register_module('deconv', force=True)
+@MODELS.register_module()
+@MODELS.register_module('deconv')
 class ConvTranspose2d(nn.ConvTranspose2d):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -103,9 +101,8 @@ class ConvTranspose2d(nn.ConvTranspose2d):
         return super().forward(x)


-@CONV_LAYERS.register_module()
-@CONV_LAYERS.register_module('deconv3d')
-@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
+@MODELS.register_module()
+@MODELS.register_module('deconv3d')
 class ConvTranspose3d(nn.ConvTranspose3d):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
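
The force=True on 'Conv' and 'Conv3d' is needed because those names are presumably already taken in the shared registry (mmcv registers plain nn.Conv2d as 'Conv' in its conv bricks), so the empty-tensor-safe wrappers above must overwrite them. A quick lookup sketch, assuming that importing mmcv.cnn triggers the registrations:

from mmengine.registry import MODELS

import mmcv.cnn  # noqa: F401, runs the register_module calls above

conv_cls = MODELS.get('Conv')  # the wrapper Conv2d defined in this file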
mmcv/cnn/builder.py (deleted)
# Copyright (c) OpenMMLab. All rights reserved.
from ..runner import Sequential
from ..utils import Registry, build_from_cfg


def build_model_from_cfg(cfg, registry, default_args=None):
    """Build a PyTorch model from config dict(s). Different from
    ``build_from_cfg``, if cfg is a list, an ``nn.Sequential`` will be built.

    Args:
        cfg (dict, list[dict]): The config of modules; it is either a config
            dict or a list of config dicts. If cfg is a list, the built
            modules will be wrapped with ``nn.Sequential``.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


MODELS = Registry('model', build_func=build_model_from_cfg)
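
This whole builder module is removed by the commit. For reference, a sketch of the list-of-configs behavior its docstring describes, using a throwaway registry and a hypothetical Block module:

import torch.nn as nn

toy = Registry('toy', build_func=build_model_from_cfg)

@toy.register_module()
class Block(nn.Module):  # hypothetical module for illustration

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

one = toy.build(dict(type='Block', channels=8))        # a single module
seq = toy.build([dict(type='Block', channels=8)] * 2)  # an nn.Sequential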
mmcv/cnn/resnet.py
@@ -4,10 +4,9 @@ from typing import Optional, Sequence, Tuple, Union
 import torch.nn as nn
 import torch.utils.checkpoint as cp
+from mmengine.model.utils import constant_init, kaiming_init
 from torch import Tensor

-from .utils import constant_init, kaiming_init


 def conv3x3(in_planes: int,
             out_planes: int,
mmcv/cnn/utils/__init__.py
@@ -2,18 +2,7 @@
 from .flops_counter import get_model_complexity_info
 from .fuse_conv_bn import fuse_conv_bn
 from .sync_bn import revert_sync_batchnorm
-from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
-                          KaimingInit, NormalInit, PretrainedInit,
-                          TruncNormalInit, UniformInit, XavierInit,
-                          bias_init_with_prob, caffe2_xavier_init,
-                          constant_init, initialize, kaiming_init,
-                          normal_init, trunc_normal_init, uniform_init,
-                          xavier_init)

 __all__ = [
-    'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
-    'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
-    'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
-    'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
-    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
-    'Caffe2XavierInit', 'revert_sync_batchnorm'
+    'get_model_complexity_info', 'fuse_conv_bn', 'revert_sync_batchnorm'
 ]
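
The weight-initialization helpers are no longer re-exported here; as the import changes throughout this diff show, callers now take them from mmengine:

# Old: from mmcv.cnn import constant_init, kaiming_init, xavier_init
from mmengine.model.utils import (constant_init, kaiming_init, normal_init,
                                  xavier_init)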
mmcv/cnn/vgg.py
@@ -3,10 +3,9 @@ import logging
 from typing import List, Optional, Sequence, Tuple, Union

 import torch.nn as nn
+from mmengine.model.utils import constant_init, kaiming_init, normal_init
 from torch import Tensor

-from .utils import constant_init, kaiming_init, normal_init


 def conv3x3(in_planes: int, out_planes: int, dilation: int = 1) -> nn.Module:
     """3x3 convolution with padding."""
mmcv/ops/carafe.py
@@ -4,11 +4,12 @@ from typing import Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from mmengine.model.utils import normal_init, xavier_init
+from mmengine.registry import MODELS
 from torch import Tensor
 from torch.autograd import Function
 from torch.nn.modules.module import Module

-from ..cnn import UPSAMPLE_LAYERS, normal_init, xavier_init
 from ..utils import ext_loader

 ext_module = ext_loader.load_ext('_ext', [
@@ -219,7 +220,7 @@ class CARAFE(Module):
                           self.scale_factor)


-@UPSAMPLE_LAYERS.register_module(name='carafe')
+@MODELS.register_module(name='carafe')
 class CARAFEPack(nn.Module):
     """A unified package of CARAFE upsampler that contains: 1) channel
     compressor 2) content encoder 3) CARAFE op.
mmcv/ops/cc_attention.py
@@ -2,8 +2,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from mmengine.registry import MODELS

-from mmcv.cnn import PLUGIN_LAYERS, Scale
+from mmcv.cnn import Scale


 def NEG_INF_DIAG(n: int, device: torch.device) -> torch.Tensor:
@@ -15,7 +16,7 @@ def NEG_INF_DIAG(n: int, device: torch.device) -> torch.Tensor:
     return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)


-@PLUGIN_LAYERS.register_module()
+@MODELS.register_module()
 class CrissCrossAttention(nn.Module):
     """Criss-Cross Attention Module.
mmcv/ops/deform_conv.py
@@ -4,14 +4,15 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from mmengine import print_log
+from mmengine.registry import MODELS
 from torch import Tensor
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single

 from mmcv.utils import deprecated_api_warning
-from ..cnn import CONV_LAYERS
-from ..utils import ext_loader, print_log
+from ..utils import ext_loader

 ext_module = ext_loader.load_ext('_ext', [
     'deform_conv_forward', 'deform_conv_backward_input',
@@ -330,7 +331,7 @@ class DeformConv2d(nn.Module):
         return s


-@CONV_LAYERS.register_module('DCN')
+@MODELS.register_module('DCN')
 class DeformConv2dPack(DeformConv2d):
     """A Deformable Conv Encapsulation that acts as normal Conv layers.
mmcv/ops/modulated_deform_conv.py
@@ -4,13 +4,14 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
+from mmengine import print_log
+from mmengine.registry import MODELS
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single

 from mmcv.utils import deprecated_api_warning
-from ..cnn import CONV_LAYERS
-from ..utils import ext_loader, print_log
+from ..utils import ext_loader

 ext_module = ext_loader.load_ext(
     '_ext',
@@ -208,7 +209,7 @@ class ModulatedDeformConv2d(nn.Module):
                                  self.deform_groups)


-@CONV_LAYERS.register_module('DCNv2')
+@MODELS.register_module('DCNv2')
 class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
     """A ModulatedDeformable Conv Encapsulation that acts as normal Conv
     layers.
mmcv/ops/multi_scale_deform_attn.py
@@ -6,13 +6,13 @@ from typing import Optional, no_type_check
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from mmengine.model import BaseModule
+from mmengine.model.utils import constant_init, xavier_init
+from mmengine.registry import MODELS
 from torch.autograd.function import Function, once_differentiable

 import mmcv
 from mmcv import deprecated_api_warning
-from mmcv.cnn import constant_init, xavier_init
-from mmcv.cnn.bricks.registry import ATTENTION
-from mmcv.runner import BaseModule
 from ..utils import ext_loader

 ext_module = ext_loader.load_ext(
@@ -156,7 +156,7 @@ def multi_scale_deformable_attn_pytorch(
     return output.transpose(1, 2).contiguous()


-@ATTENTION.register_module()
+@MODELS.register_module()
 class MultiScaleDeformableAttention(BaseModule):
     """An attention module used in Deformable-Detr.
mmcv/ops/saconv.py
@@ -2,13 +2,15 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from mmengine.model.utils import constant_init
+from mmengine.registry import MODELS

-from mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
+from mmcv.cnn import ConvAWS2d
 from mmcv.ops.deform_conv import deform_conv2d
 from mmcv.utils import TORCH_VERSION, digit_version


-@CONV_LAYERS.register_module(name='SAC')
+@MODELS.register_module(name='SAC')
 class SAConv2d(ConvAWS2d):
     """SAC (Switchable Atrous Convolution)
mmcv/ops/sparse_conv.py
@@ -15,10 +15,10 @@ import math
 import numpy as np
 import torch
+from mmengine.registry import MODELS
 from torch.nn import init
 from torch.nn.parameter import Parameter

-from ..cnn import CONV_LAYERS
 from . import sparse_functional as Fsp
 from . import sparse_ops as ops
 from .sparse_modules import SparseModule
@@ -204,7 +204,7 @@ class SparseConvolution(SparseModule):
         return out_tensor


-@CONV_LAYERS.register_module()
+@MODELS.register_module()
 class SparseConv2d(SparseConvolution):

     def __init__(self,
@@ -230,7 +230,7 @@ class SparseConv2d(SparseConvolution):
                          indice_key=indice_key)


-@CONV_LAYERS.register_module()
+@MODELS.register_module()
 class SparseConv3d(SparseConvolution):

     def __init__(self,
@@ -256,7 +256,7 @@ class SparseConv3d(SparseConvolution):
                          indice_key=indice_key)


-@CONV_LAYERS.register_module()
+@MODELS.register_module()
 class SparseConv4d(SparseConvolution):

     def __init__(self,
@@ -282,7 +282,7 @@ class SparseConv4d(SparseConvolution):
                          indice_key=indice_key)


-@CONV_LAYERS.register_module()
+@MODELS.register_module()
 class SparseConvTranspose2d(SparseConvolution):

     def __init__(self,
@@ -309,7 +309,7 @@ class SparseConvTranspose2d(SparseConvolution):
                          indice_key=indice_key)


-@CONV_LAYERS.register_module()
+@MODELS.register_module()
 class SparseConvTranspose3d(SparseConvolution):

     def __init__(self,
@@ -336,7 +336,7 @@ class SparseConvTranspose3d(SparseConvolution):
                          indice_key=indice_key)


-@CONV_LAYERS.register_module()
+@MODELS.register_module()
 class SparseInverseConv2d(SparseConvolution):

     def __init__(self,
@@ -355,7 +355,7 @@ class SparseInverseConv2d(SparseConvolution):
                          indice_key=indice_key)


-@CONV_LAYERS.register_module()
+@MODELS.register_module()
 class SparseInverseConv3d(SparseConvolution):

     def __init__(self,
@@ -374,7 +374,7 @@ class SparseInverseConv3d(SparseConvolution):
                          indice_key=indice_key)


-@CONV_LAYERS.register_module()
+@MODELS.register_module()
 class SubMConv2d(SparseConvolution):

     def __init__(self,
@@ -401,7 +401,7 @@ class SubMConv2d(SparseConvolution):
                          indice_key=indice_key)


-@CONV_LAYERS.register_module()
+@MODELS.register_module()
 class SubMConv3d(SparseConvolution):

     def __init__(self,
@@ -428,7 +428,7 @@ class SubMConv3d(SparseConvolution):
                          indice_key=indice_key)


-@CONV_LAYERS.register_module()
+@MODELS.register_module()
 class SubMConv4d(SparseConvolution):

     def __init__(self,
mmcv/ops/sync_bn.py
@@ -4,12 +4,12 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from mmengine.registry import MODELS
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.module import Module
 from torch.nn.parameter import Parameter

-from mmcv.cnn import NORM_LAYERS
 from ..utils import ext_loader

 ext_module = ext_loader.load_ext('_ext', [
@@ -159,7 +159,7 @@ class SyncBatchNormFunction(Function):
             None, None, None, None, None


-@NORM_LAYERS.register_module(name='MMSyncBN')
+@MODELS.register_module(name='MMSyncBN')
 class SyncBatchNorm(Module):
     """Synchronized Batch Normalization.
mmcv/runner/__init__.py
 # Copyright (c) OpenMMLab. All rights reserved.
-from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
 from .base_runner import BaseRunner
 from .builder import RUNNERS, build_runner
 from .checkpoint import (CheckpointLoader, _load_checkpoint,
@@ -64,10 +63,10 @@ __all__ = [
     'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
     'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
     'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
-    'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
-    '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
-    'ModuleDict', 'ModuleList', 'GradientCumulativeOptimizerHook',
-    'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor',
-    'SegmindLoggerHook', 'LinearAnnealingMomentumUpdaterHook',
-    'LinearAnnealingLrUpdaterHook', 'ClearMLLoggerHook'
+    'allreduce_params', 'LossScaler', 'CheckpointLoader',
+    '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook',
+    'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook',
+    'DefaultRunnerConstructor', 'SegmindLoggerHook',
+    'LinearAnnealingMomentumUpdaterHook', 'LinearAnnealingLrUpdaterHook',
+    'ClearMLLoggerHook'
 ]
mmcv/runner/base_module.py (deleted)
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler
from typing import Iterable, Optional

import torch.nn as nn

from mmcv.runner.dist_utils import master_only
from mmcv.utils.logging import get_logger, logger_initialized, print_log
class BaseModule(nn.Module, metaclass=ABCMeta):
    """Base module for all modules in OpenMMLab.

    ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
    functionality for parameter initialization. Compared with
    ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.

    - ``init_cfg``: the config to control the initialization.
    - ``init_weights``: the function that initializes parameters and records
      initialization information.
    - ``_params_init_info``: used to track the parameter initialization
      information. This attribute only exists while ``init_weights`` is
      executing.

    Args:
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, init_cfg: Optional[dict] = None):
        """Initialize BaseModule, inherited from `torch.nn.Module`."""
        # NOTE init_cfg can be defined at different levels, but init_cfg
        # at a lower level has a higher priority.
        super().__init__()
        # Define the default value of init_cfg here instead of hard-coding
        # it in the init_weights() function.
        self._is_init = False
        self.init_cfg = copy.deepcopy(init_cfg)

        # Backward compatibility in derived classes:
        # if pretrained is not None:
        #     warnings.warn('DeprecationWarning: pretrained is a deprecated '
        #                   'key, please consider using init_cfg')
        #     self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

    @property
    def is_init(self) -> bool:
        return self._is_init
    def init_weights(self) -> None:
        """Initialize the weights."""

        is_top_level_module = False
        # Check whether this is a top-level module.
        if not hasattr(self, '_params_init_info'):
            # `_params_init_info` records the initialization information of
            # the parameters. The key is an :obj:`nn.Parameter` of the model
            # and the value is a dict containing
            # - init_info (str): a string describing the initialization.
            # - tmp_mean_value (FloatTensor): the mean of the parameter,
            #   which indicates whether the parameter has been modified.
            # This attribute is deleted after all parameters have been
            # initialized.
            self._params_init_info: defaultdict = defaultdict(dict)
            is_top_level_module = True

            # Initialize `_params_init_info`. When the `tmp_mean_value` of a
            # parameter is detected to have changed, the related
            # initialization information is updated.
            for name, param in self.named_parameters():
                self._params_init_info[param][
                    'init_info'] = f'The value is the same before and ' \
                                   f'after calling `init_weights` ' \
                                   f'of {self.__class__.__name__} '
                self._params_init_info[param][
                    'tmp_mean_value'] = param.data.mean()

            # Pass `_params_init_info` to all submodules. All submodules
            # share the same `_params_init_info`, so it is updated when
            # parameters are modified at any level of the model.
            for sub_module in self.modules():
                sub_module._params_init_info = self._params_init_info

        # Get the initialized logger; if none exists, create a logger
        # named `mmcv`.
        logger_names = list(logger_initialized.keys())
        logger_name = logger_names[0] if logger_names else 'mmcv'

        from ..cnn import initialize
        from ..cnn.utils.weight_init import update_init_info
        module_name = self.__class__.__name__
        if not self._is_init:
            if self.init_cfg:
                print_log(
                    f'initialize {module_name} with init_cfg {self.init_cfg}',
                    logger=logger_name)
                initialize(self, self.init_cfg)
                if isinstance(self.init_cfg, dict):
                    # Prevent the parameters of the pre-trained model from
                    # being overwritten by `init_weights`.
                    if self.init_cfg['type'] == 'Pretrained':
                        return

            for m in self.children():
                if hasattr(m, 'init_weights'):
                    m.init_weights()
                    # Users may overload `init_weights`.
                    update_init_info(
                        m,
                        init_info=f'Initialized by '
                        f'user-defined `init_weights`'
                        f' in {m.__class__.__name__} ')

            self._is_init = True
        else:
            warnings.warn(f'init_weights of {self.__class__.__name__} has '
                          f'been called more than once.')

        if is_top_level_module:
            self._dump_init_info(logger_name)

            for sub_module in self.modules():
                del sub_module._params_init_info
    @master_only
    def _dump_init_info(self, logger_name: str) -> None:
        """Dump the initialization information to a file named
        `initialization.log.json` in the work directory.

        Args:
            logger_name (str): The name of the logger.
        """
        logger = get_logger(logger_name)

        with_file_handler = False
        # Dump the information to the logger file if there is a `FileHandler`.
        for handler in logger.handlers:
            if isinstance(handler, FileHandler):
                handler.stream.write(
                    'Name of parameter - Initialization information\n')
                for name, param in self.named_parameters():
                    handler.stream.write(
                        f'\n{name} - {param.shape}: '
                        f"\n{self._params_init_info[param]['init_info']} \n")
                handler.stream.flush()
                with_file_handler = True
        if not with_file_handler:
            for name, param in self.named_parameters():
                print_log(
                    f'\n{name} - {param.shape}: '
                    f"\n{self._params_init_info[param]['init_info']} \n ",
                    logger=logger_name)

    def __repr__(self):
        s = super().__repr__()
        if self.init_cfg:
            s += f'\ninit_cfg={self.init_cfg}'
        return s
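
For reference, typical use of the deleted class: a subclass declares init_cfg, and the top-level caller invokes init_weights() once after assembling the model. A minimal sketch against this old mmcv API (TinyHead is hypothetical):

import torch.nn as nn

class TinyHead(BaseModule):

    def __init__(self):
        super().__init__(
            init_cfg=dict(type='Normal', layer='Conv2d', std=0.01))
        self.conv = nn.Conv2d(16, 4, 1)

head = TinyHead()
head.init_weights()  # applies the Normal init and records per-param info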
class Sequential(BaseModule, nn.Sequential):
    """Sequential module in OpenMMLab.

    Args:
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, *args, init_cfg: Optional[dict] = None):
        BaseModule.__init__(self, init_cfg)
        nn.Sequential.__init__(self, *args)


class ModuleList(BaseModule, nn.ModuleList):
    """ModuleList in OpenMMLab.

    Args:
        modules (iterable, optional): An iterable of modules to add.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 modules: Optional[Iterable] = None,
                 init_cfg: Optional[dict] = None):
        BaseModule.__init__(self, init_cfg)
        nn.ModuleList.__init__(self, modules)


class ModuleDict(BaseModule, nn.ModuleDict):
    """ModuleDict in OpenMMLab.

    Args:
        modules (dict, optional): A mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module).
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 modules: Optional[dict] = None,
                 init_cfg: Optional[dict] = None):
        BaseModule.__init__(self, init_cfg)
        nn.ModuleDict.__init__(self, modules)
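
The three containers behave like their torch counterparts but accept init_cfg, so their children participate in the same init_weights machinery. A short sketch:

import torch.nn as nn

layers = ModuleList(
    [nn.Linear(8, 8) for _ in range(3)],
    init_cfg=dict(type='Xavier', layer='Linear'))
layers.init_weights()  # initializes every child Linear via XavierInit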