Unverified Commit 65e28327 authored by Sun Jiahao's avatar Sun Jiahao Committed by GitHub
Browse files

[Enhance] Support more Torchsparse blocks (#2532)

* enhance

* enhance

* fix __init__

* fix minku backbone
parent b6571d22
......@@ -5,8 +5,7 @@ from mmengine.model import BaseModule
from mmengine.registry import MODELS
from torch import Tensor, nn
from mmdet3d.models.layers import (TorchSparseConvModule,
TorchSparseResidualBlock)
from mmdet3d.models.layers import TorchSparseBasicBlock, TorchSparseConvModule
from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
from mmdet3d.utils import OptMultiConfig
......@@ -64,11 +63,11 @@ class MinkUNetBackbone(BaseModule):
encoder_channels[i],
kernel_size=2,
stride=2),
TorchSparseResidualBlock(
TorchSparseBasicBlock(
encoder_channels[i],
encoder_channels[i + 1],
kernel_size=3),
TorchSparseResidualBlock(
TorchSparseBasicBlock(
encoder_channels[i + 1],
encoder_channels[i + 1],
kernel_size=3)))
......@@ -82,11 +81,11 @@ class MinkUNetBackbone(BaseModule):
stride=2,
transposed=True),
nn.Sequential(
TorchSparseResidualBlock(
TorchSparseBasicBlock(
decoder_channels[i + 1] + encoder_channels[-2 - i],
decoder_channels[i + 1],
kernel_size=3),
TorchSparseResidualBlock(
TorchSparseBasicBlock(
decoder_channels[i + 1],
decoder_channels[i + 1],
kernel_size=3))
......
......@@ -14,7 +14,8 @@ from .pointnet_modules import (PAConvCUDASAModule, PAConvCUDASAModuleMSG,
build_sa_module)
from .sparse_block import (SparseBasicBlock, SparseBottleneck,
make_sparse_convmodule)
from .torchsparse_block import TorchSparseConvModule, TorchSparseResidualBlock
from .torchsparse_block import (TorchSparseBasicBlock, TorchSparseBottleneck,
TorchSparseConvModule)
from .transformer import GroupFree3DMHA
from .vote_module import VoteModule
......@@ -28,5 +29,5 @@ __all__ = [
'nms_normal_bev', 'build_sa_module', 'PointSAModuleMSG', 'PointSAModule',
'PointFPModule', 'PAConvSAModule', 'PAConvSAModuleMSG',
'PAConvCUDASAModule', 'PAConvCUDASAModuleMSG', 'TorchSparseConvModule',
'TorchSparseResidualBlock'
'TorchSparseBasicBlock', 'TorchSparseBottleneck'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmengine.registry import MODELS
......@@ -7,12 +8,21 @@ def register_torchsparse() -> bool:
try:
from torchsparse.nn import (BatchNorm, Conv3d, GroupNorm, LeakyReLU,
ReLU)
from torchsparse.nn.utils import fapply
from torchsparse.tensor import SparseTensor
except ImportError:
return False
else:
class SyncBatchNorm(nn.SyncBatchNorm):
def forward(self, input: SparseTensor) -> SparseTensor:
return fapply(input, super().forward)
MODELS._register_module(Conv3d, 'TorchSparseConv3d')
MODELS._register_module(BatchNorm, 'TorchSparseBatchNorm')
MODELS._register_module(GroupNorm, 'TorchSparseGroupNorm')
MODELS._register_module(BatchNorm, 'TorchSparseBN')
MODELS._register_module(SyncBatchNorm, 'TorchSparseSyncBN')
MODELS._register_module(GroupNorm, 'TorchSparseGN')
MODELS._register_module(ReLU, 'TorchSparseReLU')
MODELS._register_module(LeakyReLU, 'TorchSparseLeakyReLU')
return True
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from torch import nn
from mmdet3d.utils import OptConfigType
from mmdet3d.utils import ConfigType, OptConfigType
from .torchsparse import IS_TORCHSPARSE_AVAILABLE
if IS_TORCHSPARSE_AVAILABLE:
......@@ -23,14 +24,15 @@ class TorchSparseConvModule(BaseModule):
kernel_size (int or Tuple[int]): Kernel_size of block.
stride (int or Tuple[int]): Stride of the first block. Defaults to 1.
dilation (int): Dilation of block. Defaults to 1.
bias (bool): Whether use bias in conv. Defaults to False.
transposed (bool): Whether use transposed convolution operator.
Defaults to False.
norm_cfg (:obj:`ConfigDict` or dict): The config of normalization.
init_cfg (:obj:`ConfigDict` or dict, optional): Initialization config.
Defaults to None.
"""
def __init__(
self,
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Sequence[int]],
......@@ -38,22 +40,30 @@ class TorchSparseConvModule(BaseModule):
dilation: int = 1,
bias: bool = False,
transposed: bool = False,
norm_cfg: ConfigType = dict(type='TorchSparseBN'),
act_cfg: ConfigType = dict(
type='TorchSparseReLU', inplace=True),
init_cfg: OptConfigType = None,
) -> None:
**kwargs) -> None:
super().__init__(init_cfg)
self.net = nn.Sequential(
layers = [
spnn.Conv3d(in_channels, out_channels, kernel_size, stride,
dilation, bias, transposed),
spnn.BatchNorm(out_channels),
spnn.ReLU(inplace=True),
)
dilation, bias, transposed)
]
if norm_cfg is not None:
_, norm = build_norm_layer(norm_cfg, out_channels)
layers.append(norm)
if act_cfg is not None:
activation = build_activation_layer(act_cfg)
layers.append(activation)
self.net = nn.Sequential(*layers)
def forward(self, x: SparseTensor) -> SparseTensor:
    """Apply the conv(-norm)(-act) sequence to a sparse tensor.

    Args:
        x (SparseTensor): Input torchsparse tensor.

    Returns:
        SparseTensor: Output of the ``self.net`` sequential built in
        ``__init__`` (Conv3d, then optional norm and activation layers).
    """
    out = self.net(x)
    return out
class TorchSparseResidualBlock(BaseModule):
class TorchSparseBasicBlock(BaseModule):
"""Torchsparse residual basic block for MinkUNet.
Args:
......@@ -62,38 +72,114 @@ class TorchSparseResidualBlock(BaseModule):
kernel_size (int or Tuple[int]): Kernel_size of block.
stride (int or Tuple[int]): Stride of the first block. Defaults to 1.
dilation (int): Dilation of block. Defaults to 1.
bias (bool): Whether use bias in conv. Defaults to False.
norm_cfg (:obj:`ConfigDict` or dict): The config of normalization.
init_cfg (:obj:`ConfigDict` or dict, optional): Initialization config.
Defaults to None.
"""
def __init__(
self,
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Sequence[int]],
kernel_size: Union[int, Sequence[int]] = 3,
stride: Union[int, Sequence[int]] = 1,
dilation: int = 1,
bias: bool = False,
norm_cfg: ConfigType = dict(type='TorchSparseBN'),
init_cfg: OptConfigType = None,
) -> None:
**kwargs) -> None:
super().__init__(init_cfg)
_, norm1 = build_norm_layer(norm_cfg, out_channels)
_, norm2 = build_norm_layer(norm_cfg, out_channels)
self.net = nn.Sequential(
spnn.Conv3d(in_channels, out_channels, kernel_size, stride,
dilation, bias),
spnn.BatchNorm(out_channels),
spnn.ReLU(inplace=True),
dilation, bias), norm1, spnn.ReLU(inplace=True),
spnn.Conv3d(
out_channels,
out_channels,
kernel_size,
stride=1,
dilation=dilation,
bias=bias), norm2)
if in_channels == out_channels and stride == 1:
self.downsample = nn.Identity()
else:
_, norm3 = build_norm_layer(norm_cfg, out_channels)
self.downsample = nn.Sequential(
spnn.Conv3d(
in_channels,
out_channels,
kernel_size=1,
stride=stride,
dilation=dilation,
bias=bias), norm3)
self.relu = spnn.ReLU(inplace=True)
def forward(self, x: SparseTensor) -> SparseTensor:
    """Residual forward: main branch plus shortcut, then ReLU.

    ``self.net`` is the two-conv main branch and ``self.downsample`` is
    either ``nn.Identity`` (when shapes already match) or a 1x1 conv +
    norm projection, so the element-wise add is always shape-compatible.

    Args:
        x (SparseTensor): Input torchsparse tensor.

    Returns:
        SparseTensor: ``relu(net(x) + downsample(x))``.
    """
    out = self.relu(self.net(x) + self.downsample(x))
    return out
class TorchSparseBottleneck(BaseModule):
"""Torchsparse residual basic block for MinkUNet.
Args:
in_channels (int): In channels of block.
out_channels (int): Out channels of block.
kernel_size (int or Tuple[int]): Kernel_size of block.
stride (int or Tuple[int]): Stride of the second block. Defaults to 1.
dilation (int): Dilation of block. Defaults to 1.
bias (bool): Whether use bias in conv. Defaults to False.
norm_cfg (:obj:`ConfigDict` or dict): The config of normalization.
init_cfg (:obj:`ConfigDict` or dict, optional): Initialization config.
Defaults to None.
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Sequence[int]] = 3,
stride: Union[int, Sequence[int]] = 1,
dilation: int = 1,
bias: bool = False,
norm_cfg: ConfigType = dict(type='TorchSparseBN'),
init_cfg: OptConfigType = None,
**kwargs) -> None:
super().__init__(init_cfg)
_, norm1 = build_norm_layer(norm_cfg, out_channels)
_, norm2 = build_norm_layer(norm_cfg, out_channels)
_, norm3 = build_norm_layer(norm_cfg, out_channels)
self.net = nn.Sequential(
spnn.Conv3d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
dilation=dilation,
bias=bias), norm1, spnn.ReLU(inplace=True),
spnn.Conv3d(
out_channels,
out_channels,
kernel_size,
stride,
dilation=dilation,
bias=bias), norm2, spnn.ReLU(inplace=True),
spnn.Conv3d(
out_channels,
out_channels,
kernel_size=1,
stride=1,
dilation=dilation,
bias=bias),
spnn.BatchNorm(out_channels),
)
bias=bias), norm3)
if in_channels == out_channels and stride == 1:
self.downsample = nn.Identity()
else:
_, norm4 = build_norm_layer(norm_cfg, out_channels)
self.downsample = nn.Sequential(
spnn.Conv3d(
in_channels,
......@@ -101,9 +187,7 @@ class TorchSparseResidualBlock(BaseModule):
kernel_size=1,
stride=stride,
dilation=dilation,
bias=bias),
spnn.BatchNorm(out_channels),
)
bias=bias), norm4)
self.relu = spnn.ReLU(inplace=True)
......
......@@ -7,8 +7,9 @@ from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
if IS_TORCHSPARSE_AVAILABLE:
from torchsparse import SparseTensor
from mmdet3d.models.layers.torchsparse_block import (
TorchSparseConvModule, TorchSparseResidualBlock)
from mmdet3d.models.layers.torchsparse_block import (TorchSparseBasicBlock,
TorchSparseBottleneck,
TorchSparseConvModule)
else:
pytest.skip('test requires Torchsparse', allow_module_level=True)
......@@ -53,8 +54,11 @@ def test_TorchsparseResidualBlock():
# test
input_sp_tensor = SparseTensor(voxel_features, coordinates)
sparse_block0 = TorchSparseResidualBlock(4, 16, kernel_size=3).cuda()
sparse_block0 = TorchSparseBasicBlock(4, 16, kernel_size=3).cuda()
sparse_block1 = TorchSparseBottleneck(4, 16, kernel_size=3).cuda()
# test forward
out_features = sparse_block0(input_sp_tensor)
assert out_features.F.shape == torch.Size([4, 16])
out_features0 = sparse_block0(input_sp_tensor)
out_features1 = sparse_block1(input_sp_tensor)
assert out_features0.F.shape == torch.Size([4, 16])
assert out_features1.F.shape == torch.Size([4, 16])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment