"tools/vscode:/vscode.git/clone" did not exist on "527e26df0806c4daa68c0692bc032a7fca43c2ef"
Unverified commit 0e157c31, authored by VVsssssk, committed by GitHub

[Fix]: fix mmcv.model to mmengine.model (#1750)

parent 009d5d6e
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-from mmcv.cnn import NORM_LAYERS
 from mmcv.runner import force_fp32
+from mmengine.registry import MODELS
 from torch import distributed as dist
 from torch import nn as nn
 from torch.autograd.function import Function
@@ -25,7 +25,7 @@ class AllReduce(Function):
         return grad_output


-@NORM_LAYERS.register_module('naiveSyncBN1d')
+@MODELS.register_module('naiveSyncBN1d')
 class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
     """Synchronized Batch Normalization for 3D Tensors.
@@ -98,7 +98,7 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
         return output


-@NORM_LAYERS.register_module('naiveSyncBN2d')
+@MODELS.register_module('naiveSyncBN2d')
 class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
     """Synchronized Batch Normalization for 4D Tensors.
...
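Both decorators above register under an explicit string name rather than the class name. A minimal sketch of what this buys downstream configs, assuming mmengine's standard Registry.get/build semantics and that the module above has been imported so the registration has run (the num_features value is illustrative):

from mmengine.registry import MODELS

# The decorator records the class under the given string key, so config
# dicts refer to it by that key and build() instantiates the class with
# the remaining keys as keyword arguments.
bn_cls = MODELS.get('naiveSyncBN1d')  # -> NaiveSyncBatchNorm1d
bn = MODELS.build(dict(type='naiveSyncBN1d', num_features=64))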
 # Copyright (c) OpenMMLab. All rights reserved.
-from .compat_spconv2 import register_spconv2
+from .overwrite_spconv import register_spconv2

 try:
     import spconv
...
 # Copyright (c) OpenMMLab. All rights reserved.
 import itertools

-from mmcv.cnn.bricks.registry import CONV_LAYERS
+from mmengine.registry import MODELS
 from torch.nn.parameter import Parameter
@@ -17,47 +17,28 @@ def register_spconv2():
     except ImportError:
         return False
     else:
-        CONV_LAYERS._register_module(SparseConv2d, 'SparseConv2d', force=True)
-        CONV_LAYERS._register_module(SparseConv3d, 'SparseConv3d', force=True)
-        CONV_LAYERS._register_module(SparseConv4d, 'SparseConv4d', force=True)
-        CONV_LAYERS._register_module(
-            SparseConvTranspose2d, 'SparseConvTranspose2d', force=True)
-        CONV_LAYERS._register_module(
-            SparseConvTranspose3d, 'SparseConvTranspose3d', force=True)
-        CONV_LAYERS._register_module(
-            SparseInverseConv2d, 'SparseInverseConv2d', force=True)
-        CONV_LAYERS._register_module(
-            SparseInverseConv3d, 'SparseInverseConv3d', force=True)
-        CONV_LAYERS._register_module(SubMConv2d, 'SubMConv2d', force=True)
-        CONV_LAYERS._register_module(SubMConv3d, 'SubMConv3d', force=True)
-        CONV_LAYERS._register_module(SubMConv4d, 'SubMConv4d', force=True)
+        MODELS._register_module(SparseConv2d, 'SparseConv2d', force=True)
+        MODELS._register_module(SparseConv3d, 'SparseConv3d', force=True)
+        MODELS._register_module(SparseConv4d, 'SparseConv4d', force=True)
+        MODELS._register_module(
+            SparseConvTranspose2d, 'SparseConvTranspose2d', force=True)
+        MODELS._register_module(
+            SparseConvTranspose3d, 'SparseConvTranspose3d', force=True)
+        MODELS._register_module(
+            SparseInverseConv2d, 'SparseInverseConv2d', force=True)
+        MODELS._register_module(
+            SparseInverseConv3d, 'SparseInverseConv3d', force=True)
+        MODELS._register_module(SubMConv2d, 'SubMConv2d', force=True)
+        MODELS._register_module(SubMConv3d, 'SubMConv3d', force=True)
+        MODELS._register_module(SubMConv4d, 'SubMConv4d', force=True)
+        SparseModule._version = 2
         SparseModule._load_from_state_dict = _load_from_state_dict
-        SparseModule._save_to_state_dict = _save_to_state_dict
         return True
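Every registration above passes force=True, which matters because MODELS may already hold entries under the same names (for example spconv 1.x wrappers registered elsewhere). A hedged sketch of the overwrite behavior on a throwaway Registry, using the same private _register_module helper this file calls:

from mmengine.registry import Registry

TOY = Registry('toy')

class OldConv:
    pass

class NewConv:
    pass

# First registration claims the name.
TOY._register_module(OldConv, 'SparseConv3d')
# Registering the same name again without force=True raises a KeyError in
# mmengine; force=True silently replaces the earlier entry instead.
TOY._register_module(NewConv, 'SparseConv3d', force=True)
assert TOY.get('SparseConv3d') is NewConv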
-def _save_to_state_dict(self, destination, prefix, keep_vars):
-    """Rewrite this func so the convolutional kernel weights stay compatible
-    between spconv 1.x in MMCV and spconv 2.x.
-
-    Kernel weights in MMCV spconv have shape (D, H, W, in_channels,
-    out_channels), while those in spconv 2.x have shape (out_channels, D, H,
-    W, in_channels).
-    """
-    for name, param in self._parameters.items():
-        if param is not None:
-            param = param if keep_vars else param.detach()
-            if name == 'weight':
-                dims = list(range(1, len(param.shape))) + [0]
-                param = param.permute(*dims)
-            destination[prefix + name] = param
-    for name, buf in self._buffers.items():
-        if buf is not None and name not in self._non_persistent_buffers_set:
-            destination[prefix + name] = buf if keep_vars else buf.detach()
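The deleted hook above converted weights from the spconv 2.x in-memory layout (out_channels, D, H, W, in_channels) back to the MMCV layout (D, H, W, in_channels, out_channels) at save time. Its permutation and the load-time one kept below are inverses of each other, which a quick check makes concrete (tensor sizes are illustrative):

import torch

# MMCV spconv layout: (D, H, W, in_channels, out_channels).
w_mmcv = torch.randn(3, 3, 3, 16, 32)

# Load-time permutation from this file: move the last dim to the front.
load_dims = [w_mmcv.dim() - 1] + list(range(w_mmcv.dim() - 1))  # [4, 0, 1, 2, 3]
w_spconv2 = w_mmcv.permute(*load_dims)
print(w_spconv2.shape)  # torch.Size([32, 3, 3, 3, 16]) -> (out, D, H, W, in)

# The removed save-time permutation was the inverse: first dim to the back.
save_dims = list(range(1, w_spconv2.dim())) + [0]  # [1, 2, 3, 4, 0]
assert w_spconv2.permute(*save_dims).shape == w_mmcv.shape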
 def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                           missing_keys, unexpected_keys, error_msgs):
     """Rewrite this func so the convolutional kernel weights stay compatible
@@ -66,6 +47,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
     Kernel weights in MMCV spconv have shape (D, H, W, in_channels,
     out_channels), while those in spconv 2.x have shape (out_channels, D, H,
     W, in_channels).
     """
+    version = local_metadata.get('version', None)
     for hook in self._load_state_dict_pre_hooks.values():
         hook(state_dict, prefix, local_metadata, strict, missing_keys,
              unexpected_keys, error_msgs)
@@ -83,9 +65,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
             # 0.3.* to version 0.4+
             if len(param.shape) == 0 and len(input_param.shape) == 1:
                 input_param = input_param[0]
-            dims = [len(input_param.shape) - 1] + list(
-                range(len(input_param.shape) - 1))
-            input_param = input_param.permute(*dims)
+            if version != 2:
+                dims = [len(input_param.shape) - 1] + list(
+                    range(len(input_param.shape) - 1))
+                input_param = input_param.permute(*dims)
             if input_param.shape != param.shape:
                 # local shape should match the one in checkpoint
                 error_msgs.append(
...
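Dropping the save hook works because SparseModule._version = 2 is now stamped into checkpoint metadata: new checkpoints keep the native spconv 2.x layout and skip the permute, while old MMCV-layout checkpoints (version != 2) are still converted on load. A minimal sketch of the underlying PyTorch mechanism on a plain nn.Module, relying on torch's standard state-dict metadata bookkeeping:

import torch
from torch import nn

class Toy(nn.Module):
    _version = 2  # recorded in state_dict metadata at save time

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4))

m = Toy()
sd = m.state_dict()
# torch stores per-module metadata under the module's prefix ('' for root);
# at load time this dict is handed to _load_from_state_dict as
# local_metadata, which is where the patched hook reads `version` from.
print(sd._metadata['']['version'])  # 2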
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.cnn.bricks.registry import ATTENTION
-from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING, MultiheadAttention
+from mmcv.cnn.bricks.transformer import MultiheadAttention
+from mmengine.registry import MODELS
 from torch import nn as nn


-@ATTENTION.register_module()
+@MODELS.register_module()
 class GroupFree3DMHA(MultiheadAttention):
     """A wrapper for torch.nn.MultiheadAttention for GroupFree3D.
@@ -108,7 +108,7 @@ class GroupFree3DMHA(MultiheadAttention):
             **kwargs)


-@POSITIONAL_ENCODING.register_module()
+@MODELS.register_module()
 class ConvBNPositionalEncoding(nn.Module):
     """Absolute position embedding with Conv learning.
...
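Taken together, these files fold four mmcv registries (NORM_LAYERS, CONV_LAYERS, ATTENTION, POSITIONAL_ENCODING) into mmengine's single MODELS registry, so every layer is built the same way from config. A hedged sketch: the type keys come from the registrations above, but the constructor arguments here are illustrative assumptions, not the classes' confirmed signatures:

from mmengine.registry import MODELS

# One registry now serves norm, sparse-conv, attention and
# positional-encoding layers alike; the key is whatever name
# register_module() recorded.
norm = MODELS.build(dict(type='naiveSyncBN2d', num_features=64))
pos_enc = MODELS.build(dict(type='ConvBNPositionalEncoding', input_channel=3))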