Unverified commit 65e28327, authored by Sun Jiahao, committed by GitHub

[Enhance] Support more Torchsparse blocks (#2532)

* enhance

* enhance

* fix __init__

* fix minkunet backbone
parent b6571d22
mmdet3d/models/backbones/minkunet_backbone.py

@@ -5,8 +5,7 @@
 from mmengine.model import BaseModule
 from mmengine.registry import MODELS
 from torch import Tensor, nn
 
-from mmdet3d.models.layers import (TorchSparseConvModule,
-                                   TorchSparseResidualBlock)
+from mmdet3d.models.layers import TorchSparseBasicBlock, TorchSparseConvModule
 from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
 from mmdet3d.utils import OptMultiConfig
@@ -64,11 +63,11 @@ class MinkUNetBackbone(BaseModule):
                     encoder_channels[i],
                     kernel_size=2,
                     stride=2),
-                TorchSparseResidualBlock(
+                TorchSparseBasicBlock(
                     encoder_channels[i],
                     encoder_channels[i + 1],
                     kernel_size=3),
-                TorchSparseResidualBlock(
+                TorchSparseBasicBlock(
                     encoder_channels[i + 1],
                     encoder_channels[i + 1],
                     kernel_size=3)))
@@ -82,11 +81,11 @@ class MinkUNetBackbone(BaseModule):
                     stride=2,
                     transposed=True),
                 nn.Sequential(
-                    TorchSparseResidualBlock(
+                    TorchSparseBasicBlock(
                         decoder_channels[i + 1] + encoder_channels[-2 - i],
                         decoder_channels[i + 1],
                         kernel_size=3),
-                    TorchSparseResidualBlock(
+                    TorchSparseBasicBlock(
                         decoder_channels[i + 1],
                         decoder_channels[i + 1],
                         kernel_size=3))
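For orientation, a hedged sketch (not part of the patch) of what one encoder stage of MinkUNetBackbone assembles after this change; the channel widths 32 -> 64 are hypothetical stand-ins for `encoder_channels[i]` and `encoder_channels[i + 1]`:

import torch.nn as nn

from mmdet3d.models.layers import TorchSparseBasicBlock, TorchSparseConvModule

stage = nn.Sequential(
    TorchSparseConvModule(32, 32, kernel_size=2, stride=2),  # 2x downsample
    TorchSparseBasicBlock(32, 64, kernel_size=3),  # widens the channels
    TorchSparseBasicBlock(64, 64, kernel_size=3))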
mmdet3d/models/layers/__init__.py

@@ -14,7 +14,8 @@ from .pointnet_modules import (PAConvCUDASAModule, PAConvCUDASAModuleMSG,
                                build_sa_module)
 from .sparse_block import (SparseBasicBlock, SparseBottleneck,
                            make_sparse_convmodule)
-from .torchsparse_block import TorchSparseConvModule, TorchSparseResidualBlock
+from .torchsparse_block import (TorchSparseBasicBlock, TorchSparseBottleneck,
+                                TorchSparseConvModule)
 from .transformer import GroupFree3DMHA
 from .vote_module import VoteModule
@@ -28,5 +29,5 @@ __all__ = [
     'nms_normal_bev', 'build_sa_module', 'PointSAModuleMSG', 'PointSAModule',
     'PointFPModule', 'PAConvSAModule', 'PAConvSAModuleMSG',
     'PAConvCUDASAModule', 'PAConvCUDASAModuleMSG', 'TorchSparseConvModule',
-    'TorchSparseResidualBlock'
+    'TorchSparseBasicBlock', 'TorchSparseBottleneck'
 ]
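With the export list updated, both residual blocks are importable next to the existing conv module:

from mmdet3d.models.layers import (TorchSparseBasicBlock,
                                   TorchSparseBottleneck,
                                   TorchSparseConvModule)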
mmdet3d/models/layers/torchsparse/torchsparse_wrapper.py

@@ -1,2 +1,3 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
 from mmengine.registry import MODELS
@@ -7,12 +8,21 @@ def register_torchsparse() -> bool:
     try:
         from torchsparse.nn import (BatchNorm, Conv3d, GroupNorm, LeakyReLU,
                                     ReLU)
+        from torchsparse.nn.utils import fapply
+        from torchsparse.tensor import SparseTensor
     except ImportError:
         return False
     else:
+
+        class SyncBatchNorm(nn.SyncBatchNorm):
+
+            def forward(self, input: SparseTensor) -> SparseTensor:
+                return fapply(input, super().forward)
+
         MODELS._register_module(Conv3d, 'TorchSparseConv3d')
-        MODELS._register_module(BatchNorm, 'TorchSparseBatchNorm')
-        MODELS._register_module(GroupNorm, 'TorchSparseGroupNorm')
+        MODELS._register_module(BatchNorm, 'TorchSparseBN')
+        MODELS._register_module(SyncBatchNorm, 'TorchSparseSyncBN')
+        MODELS._register_module(GroupNorm, 'TorchSparseGN')
         MODELS._register_module(ReLU, 'TorchSparseReLU')
         MODELS._register_module(LeakyReLU, 'TorchSparseLeakyReLU')
         return True
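A brief sketch of what the renamed registrations buy (assuming torchsparse is installed so the registration actually runs): `build_norm_layer` resolves the short names from a plain config dict, which is how the rewritten blocks below consume them. The width 64 is arbitrary.

from mmcv.cnn import build_norm_layer

# returns a (name, module) tuple; both types were registered above
_, bn = build_norm_layer(dict(type='TorchSparseBN'), 64)
_, sync_bn = build_norm_layer(dict(type='TorchSparseSyncBN'), 64)  # for DDP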
mmdet3d/models/layers/torchsparse_block.py

@@ -1,11 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Sequence, Union
 
+from mmcv.cnn import build_activation_layer, build_norm_layer
 from mmengine.model import BaseModule
 from torch import nn
 
-from mmdet3d.utils import OptConfigType
+from mmdet3d.utils import ConfigType, OptConfigType
 from .torchsparse import IS_TORCHSPARSE_AVAILABLE
 
 if IS_TORCHSPARSE_AVAILABLE:
@@ -23,37 +24,46 @@ class TorchSparseConvModule(BaseModule):
         kernel_size (int or Tuple[int]): Kernel_size of block.
         stride (int or Tuple[int]): Stride of the first block. Defaults to 1.
         dilation (int): Dilation of block. Defaults to 1.
+        bias (bool): Whether use bias in conv. Defaults to False.
         transposed (bool): Whether use transposed convolution operator.
             Defaults to False.
+        norm_cfg (:obj:`ConfigDict` or dict): The config of normalization.
         init_cfg (:obj:`ConfigDict` or dict, optional): Initialization config.
             Defaults to None.
     """
 
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: Union[int, Sequence[int]],
-        stride: Union[int, Sequence[int]] = 1,
-        dilation: int = 1,
-        bias: bool = False,
-        transposed: bool = False,
-        init_cfg: OptConfigType = None,
-    ) -> None:
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Sequence[int]],
+                 stride: Union[int, Sequence[int]] = 1,
+                 dilation: int = 1,
+                 bias: bool = False,
+                 transposed: bool = False,
+                 norm_cfg: ConfigType = dict(type='TorchSparseBN'),
+                 act_cfg: ConfigType = dict(
+                     type='TorchSparseReLU', inplace=True),
+                 init_cfg: OptConfigType = None,
+                 **kwargs) -> None:
         super().__init__(init_cfg)
-        self.net = nn.Sequential(
-            spnn.Conv3d(in_channels, out_channels, kernel_size, stride,
-                        dilation, bias, transposed),
-            spnn.BatchNorm(out_channels),
-            spnn.ReLU(inplace=True),
-        )
+        layers = [
+            spnn.Conv3d(in_channels, out_channels, kernel_size, stride,
+                        dilation, bias, transposed)
+        ]
+        if norm_cfg is not None:
+            _, norm = build_norm_layer(norm_cfg, out_channels)
+            layers.append(norm)
+        if act_cfg is not None:
+            activation = build_activation_layer(act_cfg)
+            layers.append(activation)
+        self.net = nn.Sequential(*layers)
 
     def forward(self, x: SparseTensor) -> SparseTensor:
         out = self.net(x)
         return out
 
 
-class TorchSparseResidualBlock(BaseModule):
+class TorchSparseBasicBlock(BaseModule):
     """Torchsparse residual basic block for MinkUNet.
 
     Args:
@@ -62,38 +72,114 @@ class TorchSparseResidualBlock(BaseModule):
         kernel_size (int or Tuple[int]): Kernel_size of block.
         stride (int or Tuple[int]): Stride of the first block. Defaults to 1.
         dilation (int): Dilation of block. Defaults to 1.
+        bias (bool): Whether use bias in conv. Defaults to False.
+        norm_cfg (:obj:`ConfigDict` or dict): The config of normalization.
         init_cfg (:obj:`ConfigDict` or dict, optional): Initialization config.
             Defaults to None.
     """
 
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: Union[int, Sequence[int]],
-        stride: Union[int, Sequence[int]] = 1,
-        dilation: int = 1,
-        bias: bool = False,
-        init_cfg: OptConfigType = None,
-    ) -> None:
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Sequence[int]] = 3,
+                 stride: Union[int, Sequence[int]] = 1,
+                 dilation: int = 1,
+                 bias: bool = False,
+                 norm_cfg: ConfigType = dict(type='TorchSparseBN'),
+                 init_cfg: OptConfigType = None,
+                 **kwargs) -> None:
         super().__init__(init_cfg)
+        _, norm1 = build_norm_layer(norm_cfg, out_channels)
+        _, norm2 = build_norm_layer(norm_cfg, out_channels)
+
         self.net = nn.Sequential(
             spnn.Conv3d(in_channels, out_channels, kernel_size, stride,
-                        dilation, bias),
-            spnn.BatchNorm(out_channels),
-            spnn.ReLU(inplace=True),
+                        dilation, bias), norm1, spnn.ReLU(inplace=True),
+            spnn.Conv3d(
+                out_channels,
+                out_channels,
+                kernel_size,
+                stride=1,
+                dilation=dilation,
+                bias=bias), norm2)
+
+        if in_channels == out_channels and stride == 1:
+            self.downsample = nn.Identity()
+        else:
+            _, norm3 = build_norm_layer(norm_cfg, out_channels)
+            self.downsample = nn.Sequential(
+                spnn.Conv3d(
+                    in_channels,
+                    out_channels,
+                    kernel_size=1,
+                    stride=stride,
+                    dilation=dilation,
+                    bias=bias), norm3)
+
+        self.relu = spnn.ReLU(inplace=True)
+
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        out = self.relu(self.net(x) + self.downsample(x))
+        return out
+
+
+class TorchSparseBottleneck(BaseModule):
+    """Torchsparse residual bottleneck block for MinkUNet.
+
+    Args:
+        in_channels (int): In channels of block.
+        out_channels (int): Out channels of block.
+        kernel_size (int or Tuple[int]): Kernel_size of block.
+        stride (int or Tuple[int]): Stride of the second block. Defaults to 1.
+        dilation (int): Dilation of block. Defaults to 1.
+        bias (bool): Whether use bias in conv. Defaults to False.
+        norm_cfg (:obj:`ConfigDict` or dict): The config of normalization.
+        init_cfg (:obj:`ConfigDict` or dict, optional): Initialization config.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Sequence[int]] = 3,
+                 stride: Union[int, Sequence[int]] = 1,
+                 dilation: int = 1,
+                 bias: bool = False,
+                 norm_cfg: ConfigType = dict(type='TorchSparseBN'),
+                 init_cfg: OptConfigType = None,
+                 **kwargs) -> None:
+        super().__init__(init_cfg)
+        _, norm1 = build_norm_layer(norm_cfg, out_channels)
+        _, norm2 = build_norm_layer(norm_cfg, out_channels)
+        _, norm3 = build_norm_layer(norm_cfg, out_channels)
+
+        self.net = nn.Sequential(
+            spnn.Conv3d(
+                in_channels,
+                out_channels,
+                kernel_size=1,
+                stride=1,
+                dilation=dilation,
+                bias=bias), norm1, spnn.ReLU(inplace=True),
             spnn.Conv3d(
                 out_channels,
                 out_channels,
                 kernel_size,
+                stride,
+                dilation=dilation,
+                bias=bias), norm2, spnn.ReLU(inplace=True),
+            spnn.Conv3d(
+                out_channels,
+                out_channels,
+                kernel_size=1,
                 stride=1,
                 dilation=dilation,
-                bias=bias),
-            spnn.BatchNorm(out_channels),
-        )
+                bias=bias), norm3)
+
         if in_channels == out_channels and stride == 1:
             self.downsample = nn.Identity()
         else:
+            _, norm4 = build_norm_layer(norm_cfg, out_channels)
             self.downsample = nn.Sequential(
                 spnn.Conv3d(
                     in_channels,
@@ -101,9 +187,7 @@ class TorchSparseResidualBlock(BaseModule):
                     out_channels,
                     kernel_size=1,
                     stride=stride,
                     dilation=dilation,
-                    bias=bias),
-                spnn.BatchNorm(out_channels),
-            )
+                    bias=bias), norm4)
 
         self.relu = spnn.ReLU(inplace=True)
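A hedged usage sketch of the rewritten modules (channel sizes are arbitrary; torchsparse must be installed, and a CUDA device is needed for an actual forward pass). The point of the rewrite is that normalization and activation are now swappable config dicts rather than hard-coded BatchNorm/ReLU:

from mmdet3d.models.layers import (TorchSparseBasicBlock,
                                   TorchSparseBottleneck,
                                   TorchSparseConvModule)

# defaults: TorchSparseBN + TorchSparseReLU
conv = TorchSparseConvModule(4, 16, kernel_size=3)
# swapped-in flavors via the new cfg arguments
conv_sync = TorchSparseConvModule(
    4, 16, kernel_size=3,
    norm_cfg=dict(type='TorchSparseSyncBN'),
    act_cfg=dict(type='TorchSparseLeakyReLU', inplace=True))
basic = TorchSparseBasicBlock(16, 16)  # identity shortcut
bottleneck = TorchSparseBottleneck(16, 32)  # 1x1-conv shortcut branch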
tests/.../test_torchsparse_block.py

@@ -7,8 +7,9 @@ from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
 
 if IS_TORCHSPARSE_AVAILABLE:
     from torchsparse import SparseTensor
-    from mmdet3d.models.layers.torchsparse_block import (
-        TorchSparseConvModule, TorchSparseResidualBlock)
+    from mmdet3d.models.layers.torchsparse_block import (TorchSparseBasicBlock,
+                                                         TorchSparseBottleneck,
+                                                         TorchSparseConvModule)
 else:
     pytest.skip('test requires Torchsparse', allow_module_level=True)
@@ -53,8 +54,11 @@ def test_TorchsparseResidualBlock():
     # test
     input_sp_tensor = SparseTensor(voxel_features, coordinates)
 
-    sparse_block0 = TorchSparseResidualBlock(4, 16, kernel_size=3).cuda()
+    sparse_block0 = TorchSparseBasicBlock(4, 16, kernel_size=3).cuda()
+    sparse_block1 = TorchSparseBottleneck(4, 16, kernel_size=3).cuda()
 
     # test forward
-    out_features = sparse_block0(input_sp_tensor)
-    assert out_features.F.shape == torch.Size([4, 16])
+    out_features0 = sparse_block0(input_sp_tensor)
+    out_features1 = sparse_block1(input_sp_tensor)
+    assert out_features0.F.shape == torch.Size([4, 16])
+    assert out_features1.F.shape == torch.Size([4, 16])
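The test's voxel fixtures are elided above; a hypothetical minimal setup consistent with the asserted `[4, 16]` output shape would be four voxels with 4-channel features, with coordinates in torchsparse v1.4's `(x, y, z, batch_idx)` integer layout:

import torch
from torchsparse import SparseTensor

voxel_features = torch.rand(4, 4).cuda()  # N=4 voxels, 4 channels each
coordinates = torch.randint(0, 10, (4, 4), dtype=torch.int32).cuda()
input_sp_tensor = SparseTensor(voxel_features, coordinates)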