Unverified Commit 98d26420 authored by Sun Jiahao's avatar Sun Jiahao Committed by GitHub
Browse files

[Feature] Spvcnn backbone (#2320)

* add cylindrical voxelization & voxel feature encoder

* add cylindrical voxelization & voxel feature encoder

* add voxel-wise label & voxelization UT

* fix vfe

* fix vfe UT

* rename voxel encoder & add more test case

* fix type hint

* temporarily refactoring mmcv's voxelize and dynamic in mmdet3d for data_preprocesser

* _forward

* del checkpoints

* add if tp

* add predict

* fix vfe init bug & fix UT

* add grid_size & move voxelization code

* fix import bug

* keep radian to follow origin

* add doc string

* fix type hint

* add minkunet voxelization and loss function

* fix data

* init train

* fix sparsetensor typehint

* rename dir

* fix data config

* fix data config

* fix batch_size & replace dynamic_scatter

* fix conflicts 2

* fix conflicts on s_70

* Alignment of the original implementation

* rename config

* add worker_init_fn_hook

* remove test_config & worker hook

* add UT

* fix polarmix UT

* init spcvnn backbone

* add seed for cr0p5

* spvcnn_init

* format

* rename SemanticKittiDataset

* add platte & fix visual bug

* add platte & fix data info bug

* fix ut

* fix ut

* fix semantic_kitti ut

* train init

* fix docstring

* fix config name

* rename layer

* fix doc string

* fix review

* remove filter data

* rename config

* rename backbone

* rename backbone 2

* refactor voxel2point

* fix coors typo

* fix ut

* fix ut

* pred in segmentor

* fix get voxel seg

* resolve comments

* rename p2v and v2p

* rename points and voxels
parent f4b0174b
model = dict(
type='MinkUNet',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
voxel_type='minkunet',
voxel_layer=dict(
max_num_points=-1,
point_cloud_range=[-100, -100, -20, 100, 100, 20],
voxel_size=[0.05, 0.05, 0.05],
max_voxels=(-1, -1)),
),
backbone=dict(
type='SPVCNNBackbone',
in_channels=4,
base_channels=32,
encoder_channels=[32, 64, 128, 256],
decoder_channels=[256, 128, 96, 96],
num_stages=4,
drop_ratio=0.3),
decode_head=dict(
type='MinkUNetHead',
channels=96,
num_classes=19,
dropout_ratio=0,
loss_decode=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
ignore_index=19),
train_cfg=dict(),
test_cfg=dict())
_base_ = ['./spvcnn_w32_8xb2-15e_semantickitti.py']
model = dict(
backbone=dict(
base_channels=16,
encoder_channels=[16, 32, 64, 128],
decoder_channels=[128, 64, 48, 48]),
decode_head=dict(channels=48))
randomness = dict(seed=1588147245)
_base_ = ['./spvcnn_w32_8xb2-15e_semantickitti.py']
model = dict(
backbone=dict(
base_channels=20,
encoder_channels=[20, 40, 81, 163],
decoder_channels=[163, 81, 61, 61]),
decode_head=dict(channels=61))
_base_ = [
'../_base_/datasets/semantickitti.py', '../_base_/models/spvcnn.py',
'../_base_/default_runtime.py'
]
train_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_seg_3d=True,
seg_3d_dtype='np.int32',
seg_offset=2**16,
dataset_type='semantickitti'),
dict(type='PointSegClassMapping'),
dict(
type='GlobalRotScaleTrans',
rot_range=[0., 6.28318531],
scale_ratio_range=[0.95, 1.05],
translation_std=[0, 0, 0],
),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]
train_dataloader = dict(
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
lr = 0.24
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='SGD', lr=lr, weight_decay=0.0001, momentum=0.9, nesterov=True))
param_scheduler = [
dict(
type='LinearLR', start_factor=0.008, by_epoch=False, begin=0, end=125),
dict(
type='CosineAnnealingLR',
begin=0,
T_max=15,
by_epoch=True,
eta_min=1e-5,
convert_to_iter_based=True)
]
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=15, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1))
randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
env_cfg = dict(cudnn_benchmark=True)
...@@ -11,10 +11,11 @@ from .nostem_regnet import NoStemRegNet ...@@ -11,10 +11,11 @@ from .nostem_regnet import NoStemRegNet
from .pointnet2_sa_msg import PointNet2SAMSG from .pointnet2_sa_msg import PointNet2SAMSG
from .pointnet2_sa_ssg import PointNet2SASSG from .pointnet2_sa_ssg import PointNet2SASSG
from .second import SECOND from .second import SECOND
from .spvcnn_backone import SPVCNNBackbone
__all__ = [ __all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG', 'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv', 'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv',
'MinkUNetBackbone' 'MinkUNetBackbone', 'SPVCNNBackbone'
] ]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
import torch
from mmengine.registry import MODELS
from torch import Tensor, nn
from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
from mmdet3d.utils import OptMultiConfig
from .minkunet_backbone import MinkUNetBackbone
if IS_TORCHSPARSE_AVAILABLE:
import torchsparse
import torchsparse.nn.functional as F
from torchsparse.nn.utils import get_kernel_offsets
from torchsparse.tensor import PointTensor, SparseTensor
else:
PointTensor = SparseTensor = None
@MODELS.register_module()
class SPVCNNBackbone(MinkUNetBackbone):
"""SPVCNN backbone with torchsparse backend.
More details can be found in `paper <https://arxiv.org/abs/2007.16100>`_ .
Args:
in_channels (int): Number of input voxel feature channels.
Defaults to 4.
base_channels (int): The input channels for first encoder layer.
Defaults to 32.
encoder_channels (List[int]): Convolutional channels of each encode
layer. Defaults to [32, 64, 128, 256].
decoder_channels (List[int]): Convolutional channels of each decode
layer. Defaults to [256, 128, 96, 96].
num_stages (int): Number of stages in encoder and decoder.
Defaults to 4.
drop_ratio (float): Dropout ratio of voxel features. Defaults to 0.3.
init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`]
, optional): Initialization config dict. Defaults to None.
"""
def __init__(self,
in_channels: int = 4,
base_channels: int = 32,
encoder_channels: Sequence[int] = [32, 64, 128, 256],
decoder_channels: Sequence[int] = [256, 128, 96, 96],
num_stages: int = 4,
drop_ratio: float = 0.3,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(
in_channels=in_channels,
base_channels=base_channels,
encoder_channels=encoder_channels,
decoder_channels=decoder_channels,
num_stages=num_stages,
init_cfg=init_cfg)
self.point_transforms = nn.ModuleList([
nn.Sequential(
nn.Linear(base_channels, encoder_channels[-1]),
nn.BatchNorm1d(encoder_channels[-1]), nn.ReLU(True)),
nn.Sequential(
nn.Linear(encoder_channels[-1], decoder_channels[2]),
nn.BatchNorm1d(decoder_channels[2]), nn.ReLU(True)),
nn.Sequential(
nn.Linear(decoder_channels[2], decoder_channels[4]),
nn.BatchNorm1d(decoder_channels[4]), nn.ReLU(True))
])
self.dropout = nn.Dropout(drop_ratio, True)
def forward(self, voxel_features: Tensor, coors: Tensor) -> PointTensor:
"""Forward function.
Args:
voxel_features (Tensor): Voxel features in shape (N, C).
coors (Tensor): Coordinates in shape (N, 4),
the columns in the order of (x_idx, y_idx, z_idx, batch_idx).
Returns:
PointTensor: Backbone features.
"""
voxels = SparseTensor(voxel_features, coors)
points = PointTensor(voxels.F, voxels.C.float())
voxels = self.initial_voxelize(points)
voxels = self.conv_input(voxels)
points = self.voxel_to_point(voxels, points)
voxels = self.point_to_voxel(voxels, points)
laterals = [voxels]
for encoder in self.encoder:
voxels = encoder(voxels)
laterals.append(voxels)
laterals = laterals[:-1][::-1]
points = self.voxel_to_point(voxels, points, self.point_transforms[0])
voxels = self.point_to_voxel(voxels, points)
voxels.F = self.dropout(voxels.F)
decoder_outs = []
for i, decoder in enumerate(self.decoder):
voxels = decoder[0](voxels)
voxels = torchsparse.cat((voxels, laterals[i]))
voxels = decoder[1](voxels)
decoder_outs.append(voxels)
if i == 1:
points = self.voxel_to_point(voxels, points,
self.point_transforms[1])
voxels = self.point_to_voxel(voxels, points)
voxels.F = self.dropout(voxels.F)
points = self.voxel_to_point(voxels, points, self.point_transforms[2])
return points
def initial_voxelize(self, points: PointTensor) -> SparseTensor:
"""Voxelization again based on input PointTensor.
Args:
points (PointTensor): Input points after voxelization.
Returns:
SparseTensor: New voxels.
"""
pc_hash = F.sphash(torch.floor(points.C).int())
sparse_hash = torch.unique(pc_hash)
idx_query = F.sphashquery(pc_hash, sparse_hash)
counts = F.spcount(idx_query.int(), len(sparse_hash))
inserted_coords = F.spvoxelize(
torch.floor(points.C), idx_query, counts)
inserted_coords = torch.round(inserted_coords).int()
inserted_feat = F.spvoxelize(points.F, idx_query, counts)
new_tensor = SparseTensor(inserted_feat, inserted_coords, 1)
new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords)
points.additional_features['idx_query'][1] = idx_query
points.additional_features['counts'][1] = counts
return new_tensor
def voxel_to_point(self,
voxels: SparseTensor,
points: PointTensor,
point_transform: Optional[nn.Module] = None,
nearest: bool = False) -> PointTensor:
"""Feed voxel features to points.
Args:
voxels (SparseTensor): Input voxels.
points (PointTensor): Input points.
point_transform (nn.Module, optional): Point transform module
for input point features. Defaults to None.
nearest (bool): Whether to use nearest neighbor interpolation.
Defaults to False.
Returns:
PointTensor: Points with new features.
"""
if points.idx_query is None or points.weights is None or \
points.idx_query.get(voxels.s) is None or \
points.weights.get(voxels.s) is None:
offsets = get_kernel_offsets(
2, voxels.s, 1, device=points.F.device)
old_hash = F.sphash(
torch.cat([
torch.floor(points.C[:, :3] / voxels.s[0]).int() *
voxels.s[0], points.C[:, -1].int().view(-1, 1)
], 1), offsets)
pc_hash = F.sphash(voxels.C.to(points.F.device))
idx_query = F.sphashquery(old_hash, pc_hash)
weights = F.calc_ti_weights(
points.C, idx_query,
scale=voxels.s[0]).transpose(0, 1).contiguous()
idx_query = idx_query.transpose(0, 1).contiguous()
if nearest:
weights[:, 1:] = 0.
idx_query[:, 1:] = -1
new_features = F.spdevoxelize(voxels.F, idx_query, weights)
new_tensor = PointTensor(
new_features,
points.C,
idx_query=points.idx_query,
weights=points.weights)
new_tensor.additional_features = points.additional_features
new_tensor.idx_query[voxels.s] = idx_query
new_tensor.weights[voxels.s] = weights
points.idx_query[voxels.s] = idx_query
points.weights[voxels.s] = weights
else:
new_features = F.spdevoxelize(voxels.F,
points.idx_query.get(voxels.s),
points.weights.get(voxels.s))
new_tensor = PointTensor(
new_features,
points.C,
idx_query=points.idx_query,
weights=points.weights)
new_tensor.additional_features = points.additional_features
if point_transform is not None:
new_tensor.F = new_tensor.F + point_transform(points.F)
return new_tensor
def point_to_voxel(self, voxels: SparseTensor,
points: PointTensor) -> SparseTensor:
"""Feed point features to voxels.
Args:
voxels (SparseTensor): Input voxels.
points (PointTensor): Input points.
Returns:
SparseTensor: Voxels with new features.
"""
if points.additional_features is None or \
points.additional_features.get('idx_query') is None or \
points.additional_features['idx_query'].get(voxels.s) is None:
pc_hash = F.sphash(
torch.cat([
torch.floor(points.C[:, :3] / voxels.s[0]).int() *
voxels.s[0], points.C[:, -1].int().view(-1, 1)
], 1))
sparse_hash = F.sphash(voxels.C)
idx_query = F.sphashquery(pc_hash, sparse_hash)
counts = F.spcount(idx_query.int(), voxels.C.shape[0])
points.additional_features['idx_query'][voxels.s] = idx_query
points.additional_features['counts'][voxels.s] = counts
else:
idx_query = points.additional_features['idx_query'][voxels.s]
counts = points.additional_features['counts'][voxels.s]
inserted_features = F.spvoxelize(points.F, idx_query, counts)
new_tensor = SparseTensor(inserted_features, voxels.C, voxels.s)
new_tensor.cmaps = voxels.cmaps
new_tensor.kmaps = voxels.kmaps
return new_tensor
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn.functional as F
from mmdet3d.registry import MODELS
def test_spvcnn_backbone():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
try:
import torchsparse # noqa: F401
except ImportError:
pytest.skip('test requires Torchsparse installation')
coordinates, features = [], []
for i in range(2):
c = torch.randint(0, 10, (100, 3)).int()
c = F.pad(c, (0, 1), mode='constant', value=i)
coordinates.append(c)
f = torch.rand(100, 4)
features.append(f)
features = torch.cat(features, dim=0).cuda()
coordinates = torch.cat(coordinates, dim=0).cuda()
cfg = dict(type='SPVCNNBackbone')
self = MODELS.build(cfg).cuda()
self.init_weights()
y = self(features, coordinates)
assert y.F.shape == torch.Size([200, 96])
assert y.C.shape == torch.Size([200, 4])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment