write_spconv2.py

# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import List, OrderedDict

from mmengine.registry import MODELS
from torch.nn.parameter import Parameter


def register_spconv2() -> bool:
    """This func registers spconv2.0 spconv ops to overwrite the default mmcv
    spconv ops."""
    try:
        from spconv.pytorch import (SparseConv2d, SparseConv3d, SparseConv4d,
                                    SparseConvTranspose2d,
                                    SparseConvTranspose3d, SparseInverseConv2d,
                                    SparseInverseConv3d, SparseModule,
                                    SubMConv2d, SubMConv3d, SubMConv4d)
    except ImportError:
        return False
    else:
        MODELS._register_module(SparseConv2d, 'SparseConv2d', force=True)
        MODELS._register_module(SparseConv3d, 'SparseConv3d', force=True)
        MODELS._register_module(SparseConv4d, 'SparseConv4d', force=True)

        MODELS._register_module(
            SparseConvTranspose2d, 'SparseConvTranspose2d', force=True)
        MODELS._register_module(
            SparseConvTranspose3d, 'SparseConvTranspose3d', force=True)

        MODELS._register_module(
            SparseInverseConv2d, 'SparseInverseConv2d', force=True)
        MODELS._register_module(
            SparseInverseConv3d, 'SparseInverseConv3d', force=True)

        MODELS._register_module(SubMConv2d, 'SubMConv2d', force=True)
        MODELS._register_module(SubMConv3d, 'SubMConv3d', force=True)
        MODELS._register_module(SubMConv4d, 'SubMConv4d', force=True)
        # Bump the version so that checkpoints saved after registration are
        # tagged as spconv 2.x and skip the legacy weight permute on load.
        SparseModule._version = 2
        # Monkey-patch the loading hook so spconv 1.x checkpoints from MMCV
        # can still be loaded into spconv 2.x modules.
        SparseModule._load_from_state_dict = _load_from_state_dict
        return True
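
# Usage sketch (illustrative, not part of the original module; assumes spconv
# 2.x is installed): once ``register_spconv2()`` returns True, sparse conv
# layers can be built from config dicts through the mmengine ``MODELS``
# registry, e.g.::
#
#     if register_spconv2():
#         conv = MODELS.build(
#             dict(
#                 type='SubMConv3d',
#                 in_channels=4,
#                 out_channels=16,
#                 kernel_size=3,
#                 indice_key='subm1'))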


def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
                          local_metadata: dict, strict: bool,
                          missing_keys: List[str], unexpected_keys: List[str],
                          error_msgs: List[str]) -> None:
    """Rewrite this func to compat the convolutional kernel weights between
    spconv 1.x in MMCV and 2.x in spconv2.x.

    Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
    while those in spcon2.x is in (out_channel,D,H,W,in_channel).
    """
    version = local_metadata.get('version', None)
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys,
             unexpected_keys, error_msgs)

    local_name_params = itertools.chain(self._parameters.items(),
                                        self._buffers.items())
    local_state = {k: v.data for k, v in local_name_params if v is not None}

    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            input_param = state_dict[key]

            # Backward compatibility: loading 1-dim tensor from
            # 0.3.* to version 0.4+
            if len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]
            # Checkpoints without version 2 metadata come from spconv 1.x:
            # move the last dim (out_channels) to the front to match the
            # spconv 2.x layout.
            if version != 2:
                dims = [len(input_param.shape) - 1] + list(
                    range(len(input_param.shape) - 1))
                input_param = input_param.permute(*dims)
            if input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append(
                    f'size mismatch for {key}: copying a param with '
                    f'shape {input_param.shape} from checkpoint, '
                    f'the shape in current model is {param.shape}.')
                continue

            if isinstance(input_param, Parameter):
                # backwards compatibility for serialized parameters
                input_param = input_param.data
            try:
                param.copy_(input_param)
            except Exception:
                error_msgs.append(
                    f'While copying the parameter named "{key}", whose '
                    f'dimensions in the model are {param.size()} and whose '
                    f'dimensions in the checkpoint are {input_param.size()}.')
        elif strict:
            missing_keys.append(key)

    if strict:
        for key, input_param in state_dict.items():
            if key.startswith(prefix):
                input_name = key[len(prefix):]
                input_name = input_name.split(
                    '.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules \
                        and input_name not in local_state:
                    unexpected_keys.append(key)