# spconv_utils.py
from typing import Set

try:
    import spconv.pytorch as spconv
except:
    import spconv as spconv

import torch.nn as nn


def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]:
    """
    Find the keys of all spconv weights that need to have their weights transposed.

    Walks the module tree below ``model`` recursively (via ``named_children``)
    and, for every ``spconv.conv.SparseConvolution`` submodule found, records
    the dotted key ``"<path.to.module>.weight"``.

    Args:
        model: Root module to search. Only its descendants are checked; the
            root module itself is not tested.
        prefix: Dotted path of ``model`` inside an enclosing module. It is
            prepended to every returned key. Leave empty when calling on the
            top-level model.

    Returns:
        A set of dotted ``"...weight"`` keys, one per SparseConvolution
        descendant (a module reachable via several paths appears once per path).
    """
    found_keys: Set[str] = set()
    for name, child in model.named_children():
        child_prefix = f"{prefix}.{name}" if prefix else name

        if isinstance(child, spconv.conv.SparseConvolution):
            # Record the weight key under its own name. child_prefix must NOT
            # be overwritten: it is reused below as the path for this module's
            # own submodules, which would otherwise be mis-prefixed with ".weight".
            found_keys.add(f"{child_prefix}.weight")

        found_keys.update(find_all_spconv_keys(child, prefix=child_prefix))

    return found_keys


def replace_feature(out, new_features):
    """
    Return a sparse tensor holding ``new_features``, independent of spconv version.

    If ``out`` provides a ``replace_feature`` method (spconv 2.x), that method
    is called and its result is returned. Otherwise (presumably spconv 1.x
    tensors — confirm), ``out.features`` is overwritten in place and ``out``
    itself is returned.

    Args:
        out: The sparse tensor whose features should be replaced.
        new_features: The replacement feature values.

    Returns:
        The tensor carrying ``new_features``: a value produced by
        ``out.replace_feature`` on spconv 2.x, or the mutated ``out`` otherwise.
    """
    # hasattr is the idiomatic capability check; out.__dir__() built the full
    # attribute list on every call, and this runs on every sparse forward pass.
    if hasattr(out, "replace_feature"):
        # spconv 2.x behaviour
        return out.replace_feature(new_features)
    # Fallback for tensors without replace_feature: assign features directly.
    out.features = new_features
    return out