# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmengine.registry import MODELS


def register_torchsparse() -> bool:
    """Register torchsparse layers in the mmengine ``MODELS`` registry.

    torchsparse is imported inside this function, so the enclosing module can
    be imported even when torchsparse is not installed.

    On success the following aliases are registered, in this order:
    ``TorchSparseConv3d``, ``TorchSparseBN``, ``TorchSparseSyncBN``,
    ``TorchSparseGN``, ``TorchSparseReLU`` and ``TorchSparseLeakyReLU``.

    Note:
        Names are registered without ``force=True``. A second call is
        therefore expected to be rejected by the registry as a duplicate
        (mmengine raises ``KeyError``). Verify this against the installed
        mmengine version.

    Returns:
        bool: ``True`` if torchsparse was importable and its layers were
        registered. ``False`` if importing torchsparse raised
        ``ImportError``; in that case nothing is registered.
    """
    try:
        from torchsparse.nn import (BatchNorm, Conv3d, GroupNorm, LeakyReLU,
                                    ReLU)
        from torchsparse.nn.utils import fapply
        from torchsparse.tensor import SparseTensor
    except ImportError:
        # torchsparse is optional: report its absence instead of failing.
        return False

    class SyncBatchNorm(nn.SyncBatchNorm):
        """``nn.SyncBatchNorm`` that takes and returns a ``SparseTensor``."""

        def forward(self, input: SparseTensor) -> SparseTensor:
            # fapply is expected to run the dense SyncBatchNorm forward on
            # the sparse tensor's features and rewrap the result as a
            # SparseTensor (see torchsparse.nn.utils.fapply).
            return fapply(input, super().forward)

    # (registry alias, layer class) pairs, registered in this order.
    layers = (
        ('TorchSparseConv3d', Conv3d),
        ('TorchSparseBN', BatchNorm),
        ('TorchSparseSyncBN', SyncBatchNorm),
        ('TorchSparseGN', GroupNorm),
        ('TorchSparseReLU', ReLU),
        ('TorchSparseLeakyReLU', LeakyReLU),
    )
    for alias, layer in layers:
        # Public equivalent of MODELS._register_module(layer, alias).
        MODELS.register_module(name=alias, module=layer)
    return True