# Copyright (c) OpenMMLab. All rights reserved.
import itertools

from mmengine.registry import MODELS
from torch.nn.parameter import Parameter


def register_spconv2() -> bool:
    """Register spconv 2.x sparse convolution ops, overwriting the default
    MMCV spconv ops in the ``MODELS`` registry.

    Returns:
        bool: ``True`` if spconv 2.x is available and its ops were
            registered, ``False`` otherwise.
    """
    try:
        from spconv.pytorch import (SparseConv2d, SparseConv3d, SparseConv4d,
                                    SparseConvTranspose2d,
                                    SparseConvTranspose3d, SparseInverseConv2d,
                                    SparseInverseConv3d, SparseModule,
                                    SubMConv2d, SubMConv3d, SubMConv4d)
    except ImportError:
        return False
    else:
        MODELS._register_module(SparseConv2d, 'SparseConv2d', force=True)
        MODELS._register_module(SparseConv3d, 'SparseConv3d', force=True)
        MODELS._register_module(SparseConv4d, 'SparseConv4d', force=True)
        MODELS._register_module(
            SparseConvTranspose2d, 'SparseConvTranspose2d', force=True)
        MODELS._register_module(
            SparseConvTranspose3d, 'SparseConvTranspose3d', force=True)
        MODELS._register_module(
            SparseInverseConv2d, 'SparseInverseConv2d', force=True)
        MODELS._register_module(
            SparseInverseConv3d, 'SparseInverseConv3d', force=True)
        MODELS._register_module(SubMConv2d, 'SubMConv2d', force=True)
        MODELS._register_module(SubMConv3d, 'SubMConv3d', force=True)
        MODELS._register_module(SubMConv4d, 'SubMConv4d', force=True)
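        # Bump the module version so newly saved checkpoints are tagged as
        # version 2; ``_load_from_state_dict`` below relies on this tag to
        # detect spconv 1.x checkpoints whose kernels need permuting.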
        SparseModule._version = 2
        SparseModule._load_from_state_dict = _load_from_state_dict
        return True
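

# A minimal usage sketch (illustrative only, not part of this module; the
# config values below are assumptions and require spconv 2.x to be
# installed):
#
#   if register_spconv2():
#       conv = MODELS.build(
#           dict(
#               type='SubMConv3d',
#               in_channels=4,
#               out_channels=16,
#               kernel_size=3,
#               indice_key='subm1'))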


def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    """Rewrite this func to compat the convolutional kernel weights between
    spconv 1.x in MMCV and 2.x in spconv2.x.

    Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
    while those in spcon2.x is in (out_channel,D,H,W,in_channel).
    """
    version = local_metadata.get('version', None)
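
    # The logic below mirrors ``torch.nn.Module._load_from_state_dict``; the
    # spconv-specific addition is the kernel permute for version-1
    # checkpoints further down.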
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys,
             unexpected_keys, error_msgs)

    local_name_params = itertools.chain(self._parameters.items(),
                                        self._buffers.items())
    local_state = {k: v.data for k, v in local_name_params if v is not None}

    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            input_param = state_dict[key]

            # Backward compatibility: loading 1-dim tensor from
            # 0.3.* to version 0.4+
            if len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]
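
            # Checkpoints written by spconv 1.x (version < 2) store kernel
            # weights as (D, H, W, in_channels, out_channels); moving the
            # last axis to the front gives the spconv 2.x layout
            # (out_channels, D, H, W, in_channels). For a 5-D weight,
            # dims = [4, 0, 1, 2, 3].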
            if version != 2:
                dims = [len(input_param.shape) - 1] + list(
                    range(len(input_param.shape) - 1))
                input_param = input_param.permute(*dims)
            if input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append(
                    f'size mismatch for {key}: copying a param with '
                    f'shape {input_param.shape} from checkpoint, '
                    f'the shape in current model is {param.shape}.')
                continue

            if isinstance(input_param, Parameter):
                # backwards compatibility for serialized parameters
                input_param = input_param.data
            try:
                param.copy_(input_param)
            except Exception as ex:
                error_msgs.append(
                    f'While copying the parameter named "{key}", whose '
                    f'dimensions in the model are {param.size()} and whose '
                    f'dimensions in the checkpoint are {input_param.size()}, '
                    f'an exception occurred: {ex.args}.')
        elif strict:
            missing_keys.append(key)

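    # In strict mode, also report checkpoint keys that do not correspond to
    # any parameter, buffer or child module of this module.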
    if strict:
        for key, input_param in state_dict.items():
            if key.startswith(prefix):
                input_name = key[len(prefix):]
                input_name = input_name.split(
                    '.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules \
                        and input_name not in local_state:
                    unexpected_keys.append(key)