Unverified Commit 0b42e351 authored by Sun Jiahao's avatar Sun Jiahao Committed by GitHub

[Feature] Add torchsparse wrapper (#2321)

* fix polarmix UT

* add torchsparse block

* fix pytest skip

* add installation in get_started

* fix name

* fix UT bug

* update doc string

* add omit sudo install
parent 0c22c625
@@ -97,6 +97,21 @@ Note:
pip install -U git+https://github.com/NVIDIA/MinkowskiEngine -v --no-deps --install-option="--blas_include_dirs=/opt/conda/include" --install-option="--blas=openblas"
```
We also support `TorchSparse` as a sparse convolution backend. If necessary, follow the original [installation guide](https://github.com/mit-han-lab/torchsparse#installation) or install it with `pip`:
```shell
sudo apt-get install libsparsehash-dev
pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0
```
or skip the `sudo` installation by running the following commands instead:
```shell
conda install -c bioconda sparsehash
export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:${YOUR_CONDA_ENVS_DIR}/include
pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0
```
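You can then quickly verify that MMDetection3D picks up the backend (a minimal check using the availability flag this package exposes):
```python
# Prints True once TorchSparse is installed and registered correctly.
from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE

print(IS_TORCHSPARSE_AVAILABLE)
```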
3. The code cannot currently be built in a CPU-only environment (where CUDA is unavailable).
### Verify the Installation
@@ -95,6 +95,21 @@ mim install "mmdet3d>=1.1.0rc0"
pip install -U git+https://github.com/NVIDIA/MinkowskiEngine -v --no-deps --install-option="--blas_include_dirs=/opt/conda/include" --install-option="--blas=openblas"
```
We also support `TorchSparse` as a sparse convolution backend. If necessary, follow the original [installation guide](https://github.com/mit-han-lab/torchsparse#installation) or install it with `pip`:
```shell
sudo apt install libsparsehash-dev
pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0
```
or skip the `sudo` installation by running the following commands instead:
```shell
conda install -c bioconda sparsehash
export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:${YOUR_CONDA_ENVS_DIR}/include
pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0
```
3. The code cannot currently be built in a CPU-only environment (where CUDA is unavailable).
### Verify the Installation
@@ -14,6 +14,7 @@ from .pointnet_modules import (PAConvCUDASAModule, PAConvCUDASAModuleMSG,
build_sa_module)
from .sparse_block import (SparseBasicBlock, SparseBottleneck,
make_sparse_convmodule)
from .torchsparse_block import TorchSparseConvModule, TorchSparseResidualBlock
from .transformer import GroupFree3DMHA
from .vote_module import VoteModule
@@ -26,5 +27,6 @@ __all__ = [
'MLP', 'box3d_multiclass_nms', 'aligned_3d_nms', 'circle_nms', 'nms_bev',
'nms_normal_bev', 'build_sa_module', 'PointSAModuleMSG', 'PointSAModule',
'PointFPModule', 'PAConvSAModule', 'PAConvSAModuleMSG',
    'PAConvCUDASAModule', 'PAConvCUDASAModuleMSG', 'TorchSparseConvModule',
    'TorchSparseResidualBlock'
]
# Copyright (c) OpenMMLab. All rights reserved.
from .torchsparse_wrapper import register_torchsparse

# Probe for TorchSparse; if it is importable, register its layers with
# MMEngine and expose the result as an availability flag.
try:
    import torchsparse  # noqa: F401
except ImportError:
    IS_TORCHSPARSE_AVAILABLE = False
else:
    IS_TORCHSPARSE_AVAILABLE = register_torchsparse()

__all__ = ['IS_TORCHSPARSE_AVAILABLE']
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import MODELS


def register_torchsparse() -> bool:
    """Register TorchSparse layers in the ``MODELS`` registry.

    Returns:
        bool: Whether TorchSparse is installed and its layers were
            registered successfully.
    """
    try:
        from torchsparse.nn import (BatchNorm, Conv3d, GroupNorm, LeakyReLU,
                                    ReLU)
    except ImportError:
        return False
    else:
        MODELS._register_module(Conv3d, 'TorchSparseConv3d')
        MODELS._register_module(BatchNorm, 'TorchSparseBatchNorm')
        MODELS._register_module(GroupNorm, 'TorchSparseGroupNorm')
        MODELS._register_module(ReLU, 'TorchSparseReLU')
        MODELS._register_module(LeakyReLU, 'TorchSparseLeakyReLU')
        return True
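Once registered, these layers can be built from plain config dicts like any other MMEngine module. A minimal sketch (the layer arguments shown are illustrative and follow `torchsparse.nn.Conv3d`'s signature):

```python
from mmengine.registry import MODELS

from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE

if IS_TORCHSPARSE_AVAILABLE:
    # Instantiate a TorchSparse 3D convolution through the registry.
    conv = MODELS.build(
        dict(
            type='TorchSparseConv3d',
            in_channels=4,
            out_channels=16,
            kernel_size=3))
```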
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union

from mmengine.model import BaseModule
from torch import nn

from mmdet3d.utils import OptConfigType
from .torchsparse import IS_TORCHSPARSE_AVAILABLE

if IS_TORCHSPARSE_AVAILABLE:
    import torchsparse.nn as spnn
    from torchsparse.tensor import SparseTensor
else:
    SparseTensor = None


class TorchSparseConvModule(BaseModule):
    """A TorchSparse conv block that bundles conv/norm/activation layers.

    Args:
        in_channels (int): In channels of the block.
        out_channels (int): Out channels of the block.
        kernel_size (int or Tuple[int]): Kernel size of the block.
        stride (int or Tuple[int]): Stride of the convolution. Defaults to 1.
        dilation (int): Dilation of the block. Defaults to 1.
        bias (bool): Whether the conv layer has a bias. Defaults to False.
        transposed (bool): Whether to use the transposed convolution
            operator. Defaults to False.
        init_cfg (:obj:`ConfigDict` or dict, optional): Initialization
            config. Defaults to None.
    """
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Sequence[int]],
stride: Union[int, Sequence[int]] = 1,
dilation: int = 1,
bias: bool = False,
transposed: bool = False,
init_cfg: OptConfigType = None,
) -> None:
super().__init__(init_cfg)
self.net = nn.Sequential(
spnn.Conv3d(in_channels, out_channels, kernel_size, stride,
dilation, bias, transposed),
spnn.BatchNorm(out_channels),
spnn.ReLU(inplace=True),
)
def forward(self, x: SparseTensor) -> SparseTensor:
out = self.net(x)
return out


class TorchSparseResidualBlock(BaseModule):
    """TorchSparse residual basic block for MinkUNet.

    Args:
        in_channels (int): In channels of the block.
        out_channels (int): Out channels of the block.
        kernel_size (int or Tuple[int]): Kernel size of the block.
        stride (int or Tuple[int]): Stride of the first convolution.
            Defaults to 1.
        dilation (int): Dilation of the block. Defaults to 1.
        bias (bool): Whether the conv layers have a bias. Defaults to False.
        init_cfg (:obj:`ConfigDict` or dict, optional): Initialization
            config. Defaults to None.
    """
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Sequence[int]],
stride: Union[int, Sequence[int]] = 1,
dilation: int = 1,
bias: bool = False,
init_cfg: OptConfigType = None,
) -> None:
super().__init__(init_cfg)
self.net = nn.Sequential(
spnn.Conv3d(in_channels, out_channels, kernel_size, stride,
dilation, bias),
spnn.BatchNorm(out_channels),
spnn.ReLU(inplace=True),
spnn.Conv3d(
out_channels,
out_channels,
kernel_size,
stride=1,
dilation=dilation,
bias=bias),
spnn.BatchNorm(out_channels),
)
if in_channels == out_channels and stride == 1:
self.downsample = nn.Identity()
else:
self.downsample = nn.Sequential(
spnn.Conv3d(
in_channels,
out_channels,
kernel_size=1,
stride=stride,
dilation=dilation,
bias=bias),
spnn.BatchNorm(out_channels),
)
self.relu = spnn.ReLU(inplace=True)
def forward(self, x: SparseTensor) -> SparseTensor:
out = self.relu(self.net(x) + self.downsample(x))
return out
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE

if IS_TORCHSPARSE_AVAILABLE:
    from torchsparse import SparseTensor

    from mmdet3d.models.layers.torchsparse_block import (
        TorchSparseConvModule, TorchSparseResidualBlock)
else:
    pytest.skip('test requires TorchSparse', allow_module_level=True)


def test_TorchSparseConvModule():
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')
    voxel_features = torch.tensor(
        [[6.56126, 0.9648336, -1.7339306, 0.315],
         [6.8162713, -2.480431, -1.3616394, 0.36],
         [11.643568, -4.744306, -1.3580885, 0.16],
         [23.482342, 6.5036807, 0.5806964, 0.35]],
        dtype=torch.float32).cuda()  # n, point_features
    coordinates = torch.tensor(
        [[12, 819, 131, 0], [16, 750, 136, 0], [16, 705, 232, 1],
         [35, 930, 469, 1]],
        dtype=torch.int32).cuda()  # n, 4(ind_x, ind_y, ind_z, batch)

    # test forward
    input_sp_tensor = SparseTensor(voxel_features, coordinates)
    sparse_conv_module = TorchSparseConvModule(4, 4, kernel_size=2,
                                               stride=2).cuda()
    out_features = sparse_conv_module(input_sp_tensor)
    assert out_features.F.shape == torch.Size([4, 4])


def test_TorchSparseResidualBlock():
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')
    voxel_features = torch.tensor(
        [[6.56126, 0.9648336, -1.7339306, 0.315],
         [6.8162713, -2.480431, -1.3616394, 0.36],
         [11.643568, -4.744306, -1.3580885, 0.16],
         [23.482342, 6.5036807, 0.5806964, 0.35]],
        dtype=torch.float32).cuda()  # n, point_features
    coordinates = torch.tensor(
        [[12, 819, 131, 0], [16, 750, 136, 0], [16, 705, 232, 1],
         [35, 930, 469, 1]],
        dtype=torch.int32).cuda()  # n, 4(ind_x, ind_y, ind_z, batch)

    # test forward
    input_sp_tensor = SparseTensor(voxel_features, coordinates)
    sparse_block0 = TorchSparseResidualBlock(4, 16, kernel_size=3).cuda()
    out_features = sparse_block0(input_sp_tensor)
    assert out_features.F.shape == torch.Size([4, 16])
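These tests require a CUDA-capable GPU. Assuming a checkout of the repository with the test file placed under `tests/`, they can be selected by keyword, for example:

```shell
pytest -k TorchSparse tests/
```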