Unverified commit 3365d7e5 authored by Sun Jiahao, committed by GitHub

[Feature] Add Lidarseg benchmark (#2530)

* enhance minkunet

* add 2x

* add config

* add flip...

* add bottleneck

* add spvcnn & cylinder3d

* add spvcnn & cylinder3d

* refactor minkunet & spvcnn

* add minkv2

* fix mink34 shared res block

* add mink spconv

* fix spconv int32 max bug

* fix spconv int32 max bug2

* add minkowski backends

* rename config

* fix minkv2 config

* fix max voxel bug

* add checkpointhook mink18

* add backbone docstring

* fix torchsparse uninstall bug

* remove ME

* fix ut

* fix cylinder3d config
parent 3fa4dc1a
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from functools import partial
from typing import List

import torch
from mmengine.model import BaseModule
from mmengine.registry import MODELS
from torch import Tensor, nn

from mmdet3d.models.layers.minkowski_engine_block import (
    IS_MINKOWSKI_ENGINE_AVAILABLE, MinkowskiBasicBlock, MinkowskiBottleneck,
    MinkowskiConvModule)
from mmdet3d.models.layers.sparse_block import (SparseBasicBlock,
                                                SparseBottleneck,
                                                make_sparse_convmodule,
                                                replace_feature)
from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE
from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
from mmdet3d.models.layers.torchsparse_block import (TorchSparseBasicBlock,
                                                     TorchSparseBottleneck,
                                                     TorchSparseConvModule)
from mmdet3d.utils import OptMultiConfig

if IS_SPCONV2_AVAILABLE:
    from spconv.pytorch import SparseConvTensor
else:
    from mmcv.ops import SparseConvTensor

if IS_TORCHSPARSE_AVAILABLE:
    import torchsparse

if IS_MINKOWSKI_ENGINE_AVAILABLE:
    import MinkowskiEngine as ME


@MODELS.register_module()
...@@ -27,12 +45,16 @@ class MinkUNetBackbone(BaseModule):
            Defaults to 4.
        base_channels (int): The input channels for first encoder layer.
            Defaults to 32.
        num_stages (int): Number of stages in encoder and decoder.
            Defaults to 4.
        encoder_channels (List[int]): Convolutional channels of each encoder
            layer. Defaults to [32, 64, 128, 256].
        encoder_blocks (List[int]): Number of blocks in each encoder layer.
            Defaults to [2, 2, 2, 2].
        decoder_channels (List[int]): Convolutional channels of each decoder
            layer. Defaults to [256, 128, 96, 96].
        decoder_blocks (List[int]): Number of blocks in each decoder layer.
            Defaults to [2, 2, 2, 2].
        block_type (str): Type of block in encoder and decoder ('basic' or
            'bottleneck'). Defaults to 'basic'.
        sparseconv_backend (str): Sparse convolution backend: 'torchsparse',
            'spconv' or 'minkowski'. Defaults to 'torchsparse'.
        init_cfg (dict or :obj:`ConfigDict` or List[dict or :obj:`ConfigDict`]
            , optional): Initialization config dict.
    """
...@@ -40,58 +62,141 @@ class MinkUNetBackbone(BaseModule):

    def __init__(self,
                 in_channels: int = 4,
                 base_channels: int = 32,
                 num_stages: int = 4,
                 encoder_channels: List[int] = [32, 64, 128, 256],
                 encoder_blocks: List[int] = [2, 2, 2, 2],
                 decoder_channels: List[int] = [256, 128, 96, 96],
                 decoder_blocks: List[int] = [2, 2, 2, 2],
                 block_type: str = 'basic',
                 sparseconv_backend: str = 'torchsparse',
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg)
        assert num_stages == len(encoder_channels) == len(decoder_channels)
        assert sparseconv_backend in [
            'torchsparse', 'spconv', 'minkowski'
        ], f'sparseconv backend: {sparseconv_backend} not supported.'
        self.num_stages = num_stages
        self.sparseconv_backend = sparseconv_backend
if sparseconv_backend == 'torchsparse':
            assert IS_TORCHSPARSE_AVAILABLE, \
                'Please follow `get_started.md` to install torchsparse.'
input_conv = TorchSparseConvModule
encoder_conv = TorchSparseConvModule
decoder_conv = TorchSparseConvModule
residual_block = TorchSparseBasicBlock if block_type == 'basic' \
else TorchSparseBottleneck
# for torchsparse, residual branch will be implemented internally
residual_branch = None
elif sparseconv_backend == 'spconv':
            if not IS_SPCONV2_AVAILABLE:
                warnings.warn('spconv 2.x is not available, '
                              'falling back to spconv 1.x in mmcv.')
input_conv = partial(
make_sparse_convmodule, conv_type='SubMConv3d')
encoder_conv = partial(
make_sparse_convmodule, conv_type='SparseConv3d')
decoder_conv = partial(
make_sparse_convmodule, conv_type='SparseInverseConv3d')
residual_block = SparseBasicBlock if block_type == 'basic' \
else SparseBottleneck
residual_branch = partial(
make_sparse_convmodule,
conv_type='SubMConv3d',
order=('conv', 'norm'))
elif sparseconv_backend == 'minkowski':
            assert IS_MINKOWSKI_ENGINE_AVAILABLE, \
                'Please follow `get_started.md` to install Minkowski Engine.'
input_conv = MinkowskiConvModule
encoder_conv = MinkowskiConvModule
decoder_conv = partial(
MinkowskiConvModule,
conv_cfg=dict(type='MinkowskiConvNdTranspose'))
residual_block = MinkowskiBasicBlock if block_type == 'basic' \
else MinkowskiBottleneck
residual_branch = partial(MinkowskiConvModule, act_cfg=None)
        self.conv_input = nn.Sequential(
            input_conv(
                in_channels,
                base_channels,
                kernel_size=3,
                padding=1,
                indice_key='subm0'),
            input_conv(
                base_channels,
                base_channels,
                kernel_size=3,
                padding=1,
                indice_key='subm0'))
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        encoder_channels.insert(0, base_channels)
        decoder_channels.insert(0, encoder_channels[-1])
        for i in range(num_stages):
            encoder_layer = [
                encoder_conv(
                    encoder_channels[i],
                    encoder_channels[i],
                    kernel_size=2,
                    stride=2,
                    indice_key=f'spconv{i+1}')
            ]
            for j in range(encoder_blocks[i]):
                if j == 0 and encoder_channels[i] != encoder_channels[i + 1]:
                    encoder_layer.append(
                        residual_block(
                            encoder_channels[i],
                            encoder_channels[i + 1],
                            downsample=residual_branch(
                                encoder_channels[i],
                                encoder_channels[i + 1],
                                kernel_size=1)
                            if residual_branch is not None else None,
                            indice_key=f'subm{i+1}'))
                else:
                    encoder_layer.append(
                        residual_block(
                            encoder_channels[i + 1],
                            encoder_channels[i + 1],
                            indice_key=f'subm{i+1}'))
            self.encoder.append(nn.Sequential(*encoder_layer))

            decoder_layer = [
                decoder_conv(
                    decoder_channels[i],
                    decoder_channels[i + 1],
                    kernel_size=2,
                    stride=2,
                    transposed=True,
                    indice_key=f'spconv{num_stages-i}')
            ]
            for j in range(decoder_blocks[i]):
                if j == 0:
                    decoder_layer.append(
                        residual_block(
                            decoder_channels[i + 1] + encoder_channels[-2 - i],
                            decoder_channels[i + 1],
                            downsample=residual_branch(
                                decoder_channels[i + 1] +
                                encoder_channels[-2 - i],
                                decoder_channels[i + 1],
                                kernel_size=1)
                            if residual_branch is not None else None,
                            indice_key=f'subm{num_stages-i-1}'))
                else:
                    decoder_layer.append(
                        residual_block(
                            decoder_channels[i + 1],
                            decoder_channels[i + 1],
                            indice_key=f'subm{num_stages-i-1}'))
            self.decoder.append(
                nn.ModuleList(
                    [decoder_layer[0],
                     nn.Sequential(*decoder_layer[1:])]))
    def forward(self, voxel_features: Tensor, coors: Tensor) -> Tensor:
        """Forward function.

        Args:
...@@ -100,9 +205,18 @@ class MinkUNetBackbone(BaseModule):
                the columns in the order of (x_idx, y_idx, z_idx, batch_idx).

        Returns:
            Tensor: Backbone features.
        """
        if self.sparseconv_backend == 'torchsparse':
            x = torchsparse.SparseTensor(voxel_features, coors)
        elif self.sparseconv_backend == 'spconv':
            spatial_shape = coors.max(0)[0][1:] + 1
            batch_size = int(coors[-1, 0]) + 1
            x = SparseConvTensor(voxel_features, coors, spatial_shape,
                                 batch_size)
        elif self.sparseconv_backend == 'minkowski':
            x = ME.SparseTensor(voxel_features, coors)

        x = self.conv_input(x)
        laterals = [x]
        for encoder_layer in self.encoder:
...@@ -113,8 +227,19 @@ class MinkUNetBackbone(BaseModule):
        decoder_outs = []
        for i, decoder_layer in enumerate(self.decoder):
            x = decoder_layer[0](x)
            if self.sparseconv_backend == 'torchsparse':
                x = torchsparse.cat((x, laterals[i]))
            elif self.sparseconv_backend == 'spconv':
                x = replace_feature(
                    x, torch.cat((x.features, laterals[i].features), dim=1))
            elif self.sparseconv_backend == 'minkowski':
                x = ME.cat(x, laterals[i])
            x = decoder_layer[1](x)
            decoder_outs.append(x)

        if self.sparseconv_backend == 'spconv':
            return decoder_outs[-1].features
        else:
            return decoder_outs[-1].F
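
A minimal usage sketch, assuming the mmdet3d registry scope is initialized and the chosen backend is installed; the config keys mirror the `__init__` arguments above, and only the backend string needs to change to switch sparse-conv libraries:

from mmdet3d.registry import MODELS

backbone = MODELS.build(
    dict(
        type='MinkUNetBackbone',
        in_channels=4,
        num_stages=4,
        base_channels=32,
        encoder_channels=[32, 64, 128, 256],
        encoder_blocks=[2, 2, 2, 2],
        decoder_channels=[256, 128, 96, 96],
        decoder_blocks=[2, 2, 2, 2],
        block_type='basic',
        sparseconv_backend='torchsparse'))  # or 'spconv' / 'minkowski'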
...@@ -6,7 +6,6 @@ from mmengine.registry import MODELS
from torch import Tensor, nn

from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
from .minkunet_backbone import MinkUNetBackbone

if IS_TORCHSPARSE_AVAILABLE:
...@@ -29,13 +28,14 @@ class SPVCNNBackbone(MinkUNetBackbone):
            Defaults to 4.
        base_channels (int): The input channels for first encoder layer.
            Defaults to 32.
        num_stages (int): Number of stages in encoder and decoder.
            Defaults to 4.
        encoder_channels (List[int]): Convolutional channels of each encoder
            layer. Defaults to [32, 64, 128, 256].
        decoder_channels (List[int]): Convolutional channels of each decoder
            layer. Defaults to [256, 128, 96, 96].
        drop_ratio (float): Dropout ratio of voxel features. Defaults to 0.3.
        sparseconv_backend (str): Sparse convolution backend. Only
            'torchsparse' is supported. Defaults to 'torchsparse'.
        init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`]
            , optional): Initialization config dict. Defaults to None.
    """
...@@ -43,18 +43,24 @@ class SPVCNNBackbone(MinkUNetBackbone):

    def __init__(self,
                 in_channels: int = 4,
                 base_channels: int = 32,
                 num_stages: int = 4,
                 encoder_channels: Sequence[int] = [32, 64, 128, 256],
                 decoder_channels: Sequence[int] = [256, 128, 96, 96],
                 drop_ratio: float = 0.3,
                 sparseconv_backend: str = 'torchsparse',
                 **kwargs) -> None:
        assert num_stages == 4, 'SPVCNN backbone only supports 4 stages.'
        assert sparseconv_backend == 'torchsparse', \
            'SPVCNN backbone only supports torchsparse backend, but got ' \
            f'sparseconv backend: {sparseconv_backend}.'
        super().__init__(
            in_channels=in_channels,
            base_channels=base_channels,
            num_stages=num_stages,
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            sparseconv_backend=sparseconv_backend,
            **kwargs)
        self.point_transforms = nn.ModuleList([
            nn.Sequential(
...@@ -69,7 +75,7 @@ class SPVCNNBackbone(MinkUNetBackbone):
        ])
        self.dropout = nn.Dropout(drop_ratio, True)
    def forward(self, voxel_features: Tensor, coors: Tensor) -> Tensor:
        """Forward function.

        Args:
...@@ -82,19 +88,19 @@ class SPVCNNBackbone(MinkUNetBackbone):
        """
        voxels = SparseTensor(voxel_features, coors)
        points = PointTensor(voxels.F, voxels.C.float())
        voxels = initial_voxelize(points)

        voxels = self.conv_input(voxels)
        points = voxel_to_point(voxels, points)
        voxels = point_to_voxel(voxels, points)
        laterals = [voxels]
        for encoder in self.encoder:
            voxels = encoder(voxels)
            laterals.append(voxels)
        laterals = laterals[:-1][::-1]

        points = voxel_to_point(voxels, points, self.point_transforms[0])
        voxels = point_to_voxel(voxels, points)
        voxels.F = self.dropout(voxels.F)

        decoder_outs = []
...@@ -104,134 +110,188 @@ class SPVCNNBackbone(MinkUNetBackbone):
            voxels = decoder[1](voxels)
            decoder_outs.append(voxels)
            if i == 1:
                points = voxel_to_point(voxels, points,
                                        self.point_transforms[1])
                voxels = point_to_voxel(voxels, points)
                voxels.F = self.dropout(voxels.F)

        points = voxel_to_point(voxels, points, self.point_transforms[2])
        return points.F


@MODELS.register_module()
class MinkUNetBackboneV2(MinkUNetBackbone):
    r"""MinkUNet backbone V2.

    Refer to https://github.com/PJLab-ADG/PCSeg/blob/master/pcseg/model/segmentor/voxel/minkunet/minkunet.py

    Args:
        sparseconv_backend (str): Sparse convolution backend. Only
            'torchsparse' is supported. Defaults to 'torchsparse'.
    """  # noqa: E501

    def __init__(self,
                 sparseconv_backend: str = 'torchsparse',
                 **kwargs) -> None:
        assert sparseconv_backend == 'torchsparse', \
            'MinkUNetBackboneV2 only supports torchsparse backend, but got ' \
            f'sparseconv backend: {sparseconv_backend}.'
        super().__init__(sparseconv_backend=sparseconv_backend, **kwargs)

    def forward(self, voxel_features: Tensor, coors: Tensor) -> Tensor:
        """Forward function.

        Args:
            voxel_features (Tensor): Voxel features in shape (N, C).
            coors (Tensor): Coordinates in shape (N, 4),
                the columns in the order of (x_idx, y_idx, z_idx, batch_idx).

        Returns:
            Tensor: Backbone features.
        """
        voxels = SparseTensor(voxel_features, coors)
        points = PointTensor(voxels.F, voxels.C.float())
        voxels = initial_voxelize(points)

        voxels = self.conv_input(voxels)
        points = voxel_to_point(voxels, points)

        laterals = [voxels]
        for encoder_layer in self.encoder:
            voxels = encoder_layer(voxels)
            laterals.append(voxels)
        laterals = laterals[:-1][::-1]

        points = voxel_to_point(voxels, points)
        output_features = [points.F]

        for i, decoder_layer in enumerate(self.decoder):
            voxels = decoder_layer[0](voxels)
            voxels = torchsparse.cat((voxels, laterals[i]))
            voxels = decoder_layer[1](voxels)
            if i % 2 == 1:
                points = voxel_to_point(voxels, points)
                output_features.append(points.F)

        points.F = torch.cat(output_features, dim=1)
        return points.F
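
With the default channel settings, `MinkUNetBackboneV2` therefore concatenates three point-feature maps: the encoder bottleneck output (256 channels) and the devoxelized decoder outputs after stages 1 and 3 (128 and 96 channels), so the returned features have 256 + 128 + 96 = 480 channels.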
def initial_voxelize(points: PointTensor) -> SparseTensor:
"""Voxelization again based on input PointTensor.
Args:
points (PointTensor): Input points after voxelization.
Returns:
SparseTensor: New voxels.
"""
pc_hash = F.sphash(torch.floor(points.C).int())
sparse_hash = torch.unique(pc_hash)
idx_query = F.sphashquery(pc_hash, sparse_hash)
counts = F.spcount(idx_query.int(), len(sparse_hash))
inserted_coords = F.spvoxelize(torch.floor(points.C), idx_query, counts)
inserted_coords = torch.round(inserted_coords).int()
inserted_feat = F.spvoxelize(points.F, idx_query, counts)
new_tensor = SparseTensor(inserted_feat, inserted_coords, 1)
new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords)
points.additional_features['idx_query'][1] = idx_query
points.additional_features['counts'][1] = counts
return new_tensor
def voxel_to_point(voxels: SparseTensor,
points: PointTensor,
point_transform: Optional[nn.Module] = None,
nearest: bool = False) -> PointTensor:
"""Feed voxel features to points.
Args:
voxels (SparseTensor): Input voxels.
points (PointTensor): Input points.
point_transform (nn.Module, optional): Point transform module
for input point features. Defaults to None.
nearest (bool): Whether to use nearest neighbor interpolation.
Defaults to False.
Returns:
PointTensor: Points with new features.
"""
if points.idx_query is None or points.weights is None or \
points.idx_query.get(voxels.s) is None or \
points.weights.get(voxels.s) is None:
offsets = get_kernel_offsets(2, voxels.s, 1, device=points.F.device)
old_hash = F.sphash(
torch.cat([
torch.floor(points.C[:, :3] / voxels.s[0]).int() * voxels.s[0],
points.C[:, -1].int().view(-1, 1)
], 1), offsets)
pc_hash = F.sphash(voxels.C.to(points.F.device))
idx_query = F.sphashquery(old_hash, pc_hash)
weights = F.calc_ti_weights(
points.C, idx_query, scale=voxels.s[0]).transpose(0,
1).contiguous()
idx_query = idx_query.transpose(0, 1).contiguous()
if nearest:
weights[:, 1:] = 0.
idx_query[:, 1:] = -1
new_features = F.spdevoxelize(voxels.F, idx_query, weights)
new_tensor = PointTensor(
new_features,
points.C,
idx_query=points.idx_query,
weights=points.weights)
new_tensor.additional_features = points.additional_features
new_tensor.idx_query[voxels.s] = idx_query
new_tensor.weights[voxels.s] = weights
points.idx_query[voxels.s] = idx_query
points.weights[voxels.s] = weights
else:
new_features = F.spdevoxelize(voxels.F, points.idx_query.get(voxels.s),
points.weights.get(voxels.s))
new_tensor = PointTensor(
new_features,
points.C,
idx_query=points.idx_query,
weights=points.weights)
new_tensor.additional_features = points.additional_features
if point_transform is not None:
new_tensor.F = new_tensor.F + point_transform(points.F)
return new_tensor
def point_to_voxel(voxels: SparseTensor, points: PointTensor) -> SparseTensor:
"""Feed point features to voxels.
Args:
voxels (SparseTensor): Input voxels.
points (PointTensor): Input points.
Returns:
SparseTensor: Voxels with new features.
"""
if points.additional_features is None or \
points.additional_features.get('idx_query') is None or \
points.additional_features['idx_query'].get(voxels.s) is None:
pc_hash = F.sphash(
torch.cat([
torch.floor(points.C[:, :3] / voxels.s[0]).int() * voxels.s[0],
points.C[:, -1].int().view(-1, 1)
], 1))
sparse_hash = F.sphash(voxels.C)
idx_query = F.sphashquery(pc_hash, sparse_hash)
counts = F.spcount(idx_query.int(), voxels.C.shape[0])
points.additional_features['idx_query'][voxels.s] = idx_query
points.additional_features['counts'][voxels.s] = counts
else:
idx_query = points.additional_features['idx_query'][voxels.s]
counts = points.additional_features['counts'][voxels.s]
inserted_features = F.spvoxelize(points.F, idx_query, counts)
new_tensor = SparseTensor(inserted_features, voxels.C, voxels.s)
new_tensor.cmaps = voxels.cmaps
new_tensor.kmaps = voxels.kmaps
return new_tensor
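
A minimal sketch of how these three helpers compose, mirroring the SPVCNN forward pass above; it reuses the `SparseTensor` and `PointTensor` classes this file already imports from torchsparse, and the input shapes are hypothetical:

feats = torch.rand(100, 4)                     # (N, C) voxel features
coords = torch.randint(0, 16, (100, 4)).int()  # columns: (x_idx, y_idx, z_idx, batch_idx)

voxels = SparseTensor(feats, coords)
points = PointTensor(voxels.F, voxels.C.float())

voxels = initial_voxelize(points)        # re-voxelize: unique coordinates at stride 1
points = voxel_to_point(voxels, points)  # trilinear devoxelization back to points
voxels = point_to_voxel(voxels, points)  # scatter point features back into voxels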
...@@ -49,6 +49,8 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
            voxelization and dynamic voxelization. Defaults to 'hard'.
        voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
            config. Defaults to None.
        batch_first (bool): Whether to put the batch dimension to the first
            dimension when getting voxel coordinates. Defaults to True.
        max_voxels (int): Maximum number of voxels in each voxel grid. Defaults
            to None.
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
...@@ -79,6 +81,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
                 voxel: bool = False,
                 voxel_type: str = 'hard',
                 voxel_layer: OptConfigType = None,
                 batch_first: bool = True,
                 max_voxels: Optional[int] = None,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
...@@ -106,6 +109,7 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
            batch_augments=batch_augments)
        self.voxel = voxel
        self.voxel_type = voxel_type
        self.batch_first = batch_first
        self.max_voxels = max_voxels
        if voxel:
            self.voxel_layer = VoxelizationByGridShape(**voxel_layer)
...@@ -440,8 +444,14 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
                        = data_sample.gt_pts_seg.pts_semantic_mask[inds]
                    res_voxel_coors = res_coors[inds]
                    res_voxels = res[inds]
                    if self.batch_first:
                        res_voxel_coors = F.pad(
                            res_voxel_coors, (1, 0), mode='constant', value=i)
                        data_sample.batch_idx = res_voxel_coors[:, 0]
                    else:
                        res_voxel_coors = F.pad(
                            res_voxel_coors, (0, 1), mode='constant', value=i)
                        data_sample.batch_idx = res_voxel_coors[:, -1]
                    data_sample.point2voxel_map = point2voxel_map.long()
                    voxels.append(res_voxels)
                    coors.append(res_voxel_coors)
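
For reference, a self-contained illustration of the two padding modes; the coordinate values are arbitrary:

import torch
import torch.nn.functional as F

coors = torch.tensor([[1, 2, 3], [4, 5, 6]])    # (x_idx, y_idx, z_idx)
# batch_first=True: prepend the batch index -> (batch_idx, x, y, z)
F.pad(coors, (1, 0), mode='constant', value=0)  # tensor([[0, 1, 2, 3], [0, 4, 5, 6]])
# batch_first=False: append the batch index -> (x, y, z, batch_idx)
F.pad(coors, (0, 1), mode='constant', value=0)  # tensor([[1, 2, 3, 0], [4, 5, 6, 0]])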
......
...@@ -5,16 +5,10 @@ import torch
from torch import Tensor
from torch import nn as nn

from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from .decode_head import Base3DDecodeHead


@MODELS.register_module()
class MinkUNetHead(Base3DDecodeHead):
...@@ -43,12 +37,12 @@ class MinkUNetHead(Base3DDecodeHead):
        ]
        return torch.cat(gt_semantic_segs)
    def predict(self, inputs: Tensor,
                batch_data_samples: SampleList) -> List[Tensor]:
        """Forward function for testing.

        Args:
            inputs (Tensor): Features from backbone.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
                data samples.
...@@ -57,7 +51,8 @@ class MinkUNetHead(Base3DDecodeHead):
        """
        seg_logits = self.forward(inputs)
        batch_idx = torch.cat(
            [data_samples.batch_idx for data_samples in batch_data_samples])
        seg_logit_list = []
        for i, data_sample in enumerate(batch_data_samples):
            seg_logit = seg_logits[batch_idx == i]
...@@ -66,15 +61,14 @@ class MinkUNetHead(Base3DDecodeHead):
        return seg_logit_list
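
The slicing above reduces to a boolean mask per sample; a toy example with hypothetical sizes:

import torch

seg_logits = torch.rand(5, 20)             # 5 points stacked over the batch, 20 classes
batch_idx = torch.tensor([0, 0, 1, 1, 1])  # concatenated from data_sample.batch_idx
seg_logit_list = [seg_logits[batch_idx == i] for i in range(2)]
# seg_logit_list[0].shape == (2, 20), seg_logit_list[1].shape == (3, 20)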
    def forward(self, x: Tensor) -> Tensor:
        """Forward function.

        Args:
            x (Tensor): Features from backbone.

        Returns:
            Tensor: Segmentation map of shape [N, C].
                Note that output contains all points from each batch.
        """
        return self.cls_seg(x)
...@@ -32,7 +32,7 @@ class PointRCNN(TwoStage3DDetector):
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None) -> None:
        super(PointRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
......
...@@ -31,6 +31,6 @@ __all__ = [
    'nms_normal_bev', 'build_sa_module', 'PointSAModuleMSG', 'PointSAModule',
    'PointFPModule', 'PAConvSAModule', 'PAConvSAModuleMSG',
    'PAConvCUDASAModule', 'PAConvCUDASAModuleMSG', 'TorchSparseConvModule',
    'TorchSparseBasicBlock', 'TorchSparseBottleneck', 'MinkowskiConvModule',
    'MinkowskiBasicBlock', 'MinkowskiBottleneck'
]
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor

from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import OptSampleList, SampleList
from .encoder_decoder import EncoderDecoder3D


@MODELS.register_module()
class MinkUNet(EncoderDecoder3D):
...@@ -25,9 +19,6 @@ class MinkUNet(EncoderDecoder3D):
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def loss(self, inputs: dict, data_samples: SampleList):
...@@ -101,7 +92,7 @@ class MinkUNet(EncoderDecoder3D):
        x = self.extract_feat(batch_inputs_dict)
        return self.decode_head.forward(x)

    def extract_feat(self, batch_inputs_dict: dict) -> Tensor:
        """Extract features from voxels.

        Args:
......
...@@ -17,7 +17,7 @@ def test_minkunet_backbone():
    coordinates, features = [], []
    for i in range(2):
        c = torch.randint(0, 16, (100, 3)).int()
        c = F.pad(c, (0, 1), mode='constant', value=i)
        coordinates.append(c)
        f = torch.rand(100, 4)
...@@ -30,5 +30,4 @@ def test_minkunet_backbone():
    self.init_weights()

    y = self(features, coordinates)
    assert y.shape == torch.Size([200, 96])
...@@ -18,7 +18,7 @@ class TestCylinder3D(unittest.TestCase):
        DefaultScope.get_instance('test_cylinder3d', scope_name='mmdet3d')
        setup_seed(0)
        cylinder3d_cfg = get_detector_cfg(
            'cylinder3d/cylinder3d_4xb4-3x_semantickitti.py')
        cylinder3d_cfg.decode_head['ignore_index'] = 1
        model = MODELS.build(cylinder3d_cfg)
        num_gt_instance = 3
......
...@@ -18,7 +18,7 @@ class TestSeg3DTTAModel(TestCase):
        assert hasattr(mmdet3d.models, 'Cylinder3D')
        DefaultScope.get_instance('test_cylinder3d', scope_name='mmdet3d')
        segmentor3d_cfg = get_detector_cfg(
            'cylinder3d/cylinder3d_4xb4-3x_semantickitti.py')
        cfg = ConfigDict(type='Seg3DTTAModel', module=segmentor3d_cfg)
        model: Seg3DTTAModel = MODELS.build(cfg)
......
...@@ -3,11 +3,12 @@ import argparse
import time

import torch
from mmengine import Config
from mmengine.device import get_device
from mmengine.registry import init_default_scope
from mmengine.runner import Runner, autocast, load_checkpoint

from mmdet3d.registry import MODELS
from tools.misc.fuse_conv_bn import fuse_module
...@@ -18,6 +19,10 @@ def parse_args():
    parser.add_argument('--samples', default=2000, help='samples to benchmark')
    parser.add_argument(
        '--log-interval', default=50, help='interval of logging')
    parser.add_argument(
        '--amp',
        action='store_true',
        help='Whether to use automatic mixed precision inference')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
...@@ -29,38 +34,23 @@ def parse_args():


def main():
    args = parse_args()

    init_default_scope('mmdet3d')

    # build config and set cudnn_benchmark
    cfg = Config.fromfile(args.config)
    if cfg.env_cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # build dataloader
    dataloader = Runner.build_dataloader(cfg.test_dataloader)

    # build model and load checkpoint
    model = MODELS.build(cfg.model)
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_module(model)
    model.to(get_device())
    model.eval()
    # the first several iterations may be very slow so skip them
...@@ -68,13 +58,13 @@ def main():
    pure_inf_time = 0

    # benchmark with several samples and take the average
    for i, data in enumerate(dataloader):

        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with autocast(enabled=args.amp):
            model.test_step(data)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time
...@@ -83,13 +73,13 @@ def main():
            pure_inf_time += elapsed

            if (i + 1) % args.log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(f'Done sample [{i + 1:<3}/ {args.samples}], '
                      f'fps: {fps:.1f} sample / s')

        if (i + 1) == args.samples:
            pure_inf_time += elapsed
            fps = (i + 1 - num_warmup) / pure_inf_time
            print(f'Overall fps: {fps:.1f} sample / s')
            break
......
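
After this refactor the benchmark follows the standard mmengine entry points and can be invoked as, e.g., `python tools/benchmark.py ${CONFIG} ${CHECKPOINT} --amp --fuse-conv-bn` (the script path is indicative; `--samples` and `--log-interval` keep their defaults of 2000 and 50).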