# Copyright (c) OpenMMLab. All rights reserved.
import itertools

from mmcv.cnn.bricks.registry import CONV_LAYERS
from torch.nn.parameter import Parameter


def register_spconv2():
    """This func registers spconv2.0 spconv ops to overwrite the default mmcv
    spconv ops."""
    try:
        from spconv.pytorch import (SparseConv2d, SparseConv3d, SparseConv4d,
                                    SparseConvTranspose2d,
                                    SparseConvTranspose3d, SparseInverseConv2d,
                                    SparseInverseConv3d, SparseModule,
                                    SubMConv2d, SubMConv3d, SubMConv4d)
    except ImportError:
        return False
    else:
        CONV_LAYERS._register_module(SparseConv2d, 'SparseConv2d', force=True)
        CONV_LAYERS._register_module(SparseConv3d, 'SparseConv3d', force=True)
        CONV_LAYERS._register_module(SparseConv4d, 'SparseConv4d', force=True)

        CONV_LAYERS._register_module(
            SparseConvTranspose2d, 'SparseConvTranspose2d', force=True)
        CONV_LAYERS._register_module(
            SparseConvTranspose3d, 'SparseConvTranspose3d', force=True)

        CONV_LAYERS._register_module(
            SparseInverseConv2d, 'SparseInverseConv2d', force=True)
        CONV_LAYERS._register_module(
            SparseInverseConv3d, 'SparseInverseConv3d', force=True)

        CONV_LAYERS._register_module(SubMConv2d, 'SubMConv2d', force=True)
        CONV_LAYERS._register_module(SubMConv3d, 'SubMConv3d', force=True)
        CONV_LAYERS._register_module(SubMConv4d, 'SubMConv4d', force=True)
        SparseModule._load_from_state_dict = _load_from_state_dict
        SparseModule._save_to_state_dict = _save_to_state_dict
        return True
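

# An illustrative usage sketch (assuming mmcv's ``build_conv_layer`` builds
# conv layers from the ``CONV_LAYERS`` registry): call ``register_spconv2``
# once before building models so that config types such as
# ``dict(type='SubMConv3d')`` resolve to the spconv 2.x implementations::
#
#     from mmcv.cnn import build_conv_layer
#     if register_spconv2():
#         conv = build_conv_layer(
#             dict(type='SubMConv3d', indice_key='subm1'), 4, 16, 3)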


def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Rewrite this func to compat the convolutional kernel weights between
    spconv 1.x in MMCV and 2.x in spconv2.x.

    Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
    while those in spcon2.x is in (out_channel,D,H,W,in_channel).
    """
    for name, param in self._parameters.items():
        if param is not None:
            param = param if keep_vars else param.detach()
            if name == 'weight':
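                # Checkpoints keep the MMCV/spconv 1.x layout, so move
                # out_channels from the front to the back:
                # (out_channels, D, H, W, in_channels) ->
                # (D, H, W, in_channels, out_channels).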
                dims = list(range(1, len(param.shape))) + [0]
                param = param.permute(*dims)
            destination[prefix + name] = param
    for name, buf in self._buffers.items():
        if buf is not None and name not in self._non_persistent_buffers_set:
            destination[prefix + name] = buf if keep_vars else buf.detach()


def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    """Rewrite this func to compat the convolutional kernel weights between
    spconv 1.x in MMCV and 2.x in spconv2.x.

    Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
    while those in spcon2.x is in (out_channel,D,H,W,in_channel).
    """
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys,
             unexpected_keys, error_msgs)

    local_name_params = itertools.chain(self._parameters.items(),
                                        self._buffers.items())
    local_state = {k: v.data for k, v in local_name_params if v is not None}

    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            input_param = state_dict[key]

            # Backward compatibility: loading 1-dim tensor from
            # 0.3.* to version 0.4+
            if len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]
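            # Checkpoints keep the MMCV/spconv 1.x layout, so move
            # out_channels from the back to the front:
            # (D, H, W, in_channels, out_channels) ->
            # (out_channels, D, H, W, in_channels).
            # For 1-dim tensors such as the bias this permute is a no-op.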
            dims = [len(input_param.shape) - 1] + list(
                range(len(input_param.shape) - 1))
            input_param = input_param.permute(*dims)
            if input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append(
                    f'size mismatch for {key}: copying a param with '
                    f'shape {input_param.shape} from checkpoint, '
                    f'the shape in current model is {param.shape}.')
                continue

            if isinstance(input_param, Parameter):
                # backwards compatibility for serialized parameters
                input_param = input_param.data
            try:
                param.copy_(input_param)
            except Exception as ex:
                error_msgs.append(
                    f'While copying the parameter named "{key}", whose '
                    f'dimensions in the model are {param.size()} and whose '
                    f'dimensions in the checkpoint are {input_param.size()}, '
                    f'an exception occurred: {ex.args}.')
        elif strict:
            missing_keys.append(key)

    if strict:
        for key, input_param in state_dict.items():
            if key.startswith(prefix):
                input_name = key[len(prefix):]
                input_name = input_name.split(
                    '.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules \
                        and input_name not in local_state:
                    unexpected_keys.append(key)
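

# A minimal self-check sketch, assuming PyTorch is installed and using a dummy
# 3x3x3 kernel with 4 input and 8 output channels; it mirrors the weight
# permutations performed in ``_save_to_state_dict`` and
# ``_load_from_state_dict`` and only runs when this file is executed directly.
if __name__ == '__main__':
    import torch

    # spconv 2.x layout: (out_channels, D, H, W, in_channels).
    spconv2_weight = torch.randn(8, 3, 3, 3, 4)

    # Save-time permutation: move out_channels from the front to the back,
    # yielding the MMCV/spconv 1.x layout (D, H, W, in_channels, out_channels).
    save_dims = list(range(1, spconv2_weight.dim())) + [0]
    mmcv_weight = spconv2_weight.permute(*save_dims)
    assert mmcv_weight.shape == (3, 3, 3, 4, 8)

    # Load-time permutation: move out_channels back to the front, recovering
    # the spconv 2.x layout. The round trip is lossless.
    load_dims = [mmcv_weight.dim() - 1] + list(range(mmcv_weight.dim() - 1))
    restored = mmcv_weight.permute(*load_dims)
    assert restored.shape == spconv2_weight.shape
    assert torch.equal(restored, spconv2_weight)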