Unverified Commit 0e157c31 authored by VVsssssk, committed by GitHub

[Fix]: fix mmcv.model to mmengine.model (#1750)

parent 009d5d6e
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-from mmcv.cnn import NORM_LAYERS
 from mmcv.runner import force_fp32
+from mmengine.registry import MODELS
 from torch import distributed as dist
 from torch import nn as nn
 from torch.autograd.function import Function
@@ -25,7 +25,7 @@ class AllReduce(Function):
         return grad_output


-@NORM_LAYERS.register_module('naiveSyncBN1d')
+@MODELS.register_module('naiveSyncBN1d')
 class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
     """Synchronized Batch Normalization for 3D Tensors.
@@ -98,7 +98,7 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
         return output


-@NORM_LAYERS.register_module('naiveSyncBN2d')
+@MODELS.register_module('naiveSyncBN2d')
 class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
     """Synchronized Batch Normalization for 4D Tensors.
......
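With this change the two naive sync-BN layers are registered in mmengine's MODELS registry instead of mmcv's NORM_LAYERS, so they can be built from a config dict like any other registered module. A minimal sketch, assuming the module above has been imported so the decorators have already run; the config values below are illustrative only:

import torch
from mmengine.registry import MODELS

# Build the re-registered layer from a config dict, the way a model config would.
norm_cfg = dict(type='naiveSyncBN1d', num_features=64)
bn = MODELS.build(norm_cfg)

x = torch.randn(8, 64)                 # (N, C) input accepted by BatchNorm1d
print(type(bn).__name__, bn(x).shape)  # NaiveSyncBatchNorm1d torch.Size([8, 64])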
 # Copyright (c) OpenMMLab. All rights reserved.
-from .compat_spconv2 import register_spconv2
+from .overwrite_spconv import register_spconv2

 try:
     import spconv
......
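This __init__ now imports register_spconv2 from .overwrite_spconv instead of .compat_spconv2; the function itself (next file) registers the spconv 2.x layers into MODELS. A hedged usage sketch follows: the import path and the SubMConv3d constructor arguments are assumptions based on spconv 2.x, not part of this diff:

from mmengine.registry import MODELS

# Import path is an assumption for illustration; adjust to the real package layout.
from mmdet3d.models.layers.spconv import register_spconv2

if register_spconv2():  # returns False when spconv 2.x is not installed
    conv = MODELS.build(
        dict(
            type='SubMConv3d',
            in_channels=4,
            out_channels=16,
            kernel_size=3,
            indice_key='subm1'))
    print(type(conv).__name__)  # SubMConv3d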
 # Copyright (c) OpenMMLab. All rights reserved.
 import itertools

-from mmcv.cnn.bricks.registry import CONV_LAYERS
+from mmengine.registry import MODELS
 from torch.nn.parameter import Parameter
@@ -17,47 +17,28 @@ def register_spconv2():
     except ImportError:
         return False
     else:
-        CONV_LAYERS._register_module(SparseConv2d, 'SparseConv2d', force=True)
-        CONV_LAYERS._register_module(SparseConv3d, 'SparseConv3d', force=True)
-        CONV_LAYERS._register_module(SparseConv4d, 'SparseConv4d', force=True)
+        MODELS._register_module(SparseConv2d, 'SparseConv2d', force=True)
+        MODELS._register_module(SparseConv3d, 'SparseConv3d', force=True)
+        MODELS._register_module(SparseConv4d, 'SparseConv4d', force=True)
-        CONV_LAYERS._register_module(
+        MODELS._register_module(
             SparseConvTranspose2d, 'SparseConvTranspose2d', force=True)
-        CONV_LAYERS._register_module(
+        MODELS._register_module(
             SparseConvTranspose3d, 'SparseConvTranspose3d', force=True)
-        CONV_LAYERS._register_module(
+        MODELS._register_module(
             SparseInverseConv2d, 'SparseInverseConv2d', force=True)
-        CONV_LAYERS._register_module(
+        MODELS._register_module(
             SparseInverseConv3d, 'SparseInverseConv3d', force=True)
-        CONV_LAYERS._register_module(SubMConv2d, 'SubMConv2d', force=True)
-        CONV_LAYERS._register_module(SubMConv3d, 'SubMConv3d', force=True)
-        CONV_LAYERS._register_module(SubMConv4d, 'SubMConv4d', force=True)
+        MODELS._register_module(SubMConv2d, 'SubMConv2d', force=True)
+        MODELS._register_module(SubMConv3d, 'SubMConv3d', force=True)
+        MODELS._register_module(SubMConv4d, 'SubMConv4d', force=True)
         SparseModule._version = 2
         SparseModule._load_from_state_dict = _load_from_state_dict
-        SparseModule._save_to_state_dict = _save_to_state_dict
         return True


-def _save_to_state_dict(self, destination, prefix, keep_vars):
-    """Rewrite this func to be compatible with the convolutional kernel
-    weights between spconv 1.x in MMCV and spconv 2.x.
-    Kernel weights in MMCV spconv have shape (D, H, W, in_channel, out_channel),
-    while those in spconv 2.x have shape (out_channel, D, H, W, in_channel).
-    """
-    for name, param in self._parameters.items():
-        if param is not None:
-            param = param if keep_vars else param.detach()
-            if name == 'weight':
-                dims = list(range(1, len(param.shape))) + [0]
-                param = param.permute(*dims)
-            destination[prefix + name] = param
-    for name, buf in self._buffers.items():
-        if buf is not None and name not in self._non_persistent_buffers_set:
-            destination[prefix + name] = buf if keep_vars else buf.detach()


 def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                           missing_keys, unexpected_keys, error_msgs):
     """Rewrite this func to be compatible with the convolutional kernel
@@ -66,6 +47,7 @@ def _load_from_state_dict(state_dict, prefix, local_metadata, strict,
     Kernel weights in MMCV spconv have shape (D, H, W, in_channel, out_channel),
     while those in spconv 2.x have shape (out_channel, D, H, W, in_channel).
     """
+    version = local_metadata.get('version', None)
     for hook in self._load_state_dict_pre_hooks.values():
         hook(state_dict, prefix, local_metadata, strict, missing_keys,
              unexpected_keys, error_msgs)
@@ -83,6 +65,7 @@ def _load_from_state_dict(state_dict, prefix, local_metadata, strict,
             # 0.3.* to version 0.4+
             if len(param.shape) == 0 and len(input_param.shape) == 1:
                 input_param = input_param[0]
+            if version != 2:
                 dims = [len(input_param.shape) - 1] + list(
                     range(len(input_param.shape) - 1))
                 input_param = input_param.permute(*dims)
......
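The docstrings above pin down the two kernel layouts: MMCV/spconv 1.x checkpoints store weights as (D, H, W, in_channel, out_channel), while spconv 2.x keeps (out_channel, D, H, W, in_channel). The permutations used by the removed _save_to_state_dict override and by the load path can be checked with plain tensors:

import torch

# spconv 2.x in-memory layout: (out_channel, D, H, W, in_channel)
w_v2 = torch.randn(16, 3, 3, 3, 4)

# Saving in the legacy layout moves the first axis to the end:
# (out, D, H, W, in) -> (D, H, W, in, out), as in the removed override.
save_dims = list(range(1, w_v2.dim())) + [0]   # [1, 2, 3, 4, 0]
w_v1 = w_v2.permute(*save_dims)
print(w_v1.shape)                              # torch.Size([3, 3, 3, 4, 16])

# Loading a legacy (version != 2) checkpoint moves the last axis to the front:
# (D, H, W, in, out) -> (out, D, H, W, in), matching _load_from_state_dict.
load_dims = [w_v1.dim() - 1] + list(range(w_v1.dim() - 1))  # [4, 0, 1, 2, 3]
print(torch.equal(w_v1.permute(*load_dims), w_v2))          # True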
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.cnn.bricks.registry import ATTENTION
-from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING, MultiheadAttention
+from mmcv.cnn.bricks.transformer import MultiheadAttention
+from mmengine.registry import MODELS
 from torch import nn as nn


-@ATTENTION.register_module()
+@MODELS.register_module()
 class GroupFree3DMHA(MultiheadAttention):
     """A wrapper for torch.nn.MultiheadAttention for GroupFree3D.
@@ -108,7 +108,7 @@ class GroupFree3DMHA(MultiheadAttention):
             **kwargs)


-@POSITIONAL_ENCODING.register_module()
+@MODELS.register_module()
 class ConvBNPositionalEncoding(nn.Module):
     """Absolute position embedding with Conv learning.
......
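GroupFree3DMHA and ConvBNPositionalEncoding are now looked up in MODELS as well. A minimal sketch of building the positional encoding from a config dict; the constructor arguments (input_channel, num_pos_feats) and the (B, N, 3) input layout are assumptions for illustration, not guaranteed by this diff:

import torch
from mmengine.registry import MODELS

pos_enc = MODELS.build(
    dict(type='ConvBNPositionalEncoding', input_channel=3, num_pos_feats=288))

xyz = torch.randn(2, 128, 3)   # (B, N, 3) candidate coordinates, assumed layout
print(pos_enc(xyz).shape)      # expected: torch.Size([2, 288, 128])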