Unverified Commit dfcf5428 authored by xizaoqu's avatar xizaoqu Committed by GitHub
Browse files

[Feature] Cylinder3d segmentor (#2344)

* update

* add cylinder3d_backbone

* add test segmentor

* add cfg

* add test backbone

* rename test cylinder3d backbone

* midway

* update, pass validation

* fix test

* update cfg
parent afa4479c
grid_shape = [480, 360, 32]
model = dict(
type='Cylinder3D',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
voxel_type='cylindrical',
voxel_layer=dict(
grid_shape=grid_shape,
point_cloud_range=[0, -3.14159265359, -4, 50, 3.14159265359, 2],
max_num_points=-1,
max_voxels=-1,
),
),
voxel_encoder=dict(
type='SegVFE',
feat_channels=[64, 128, 256, 256],
in_channels=6,
with_voxel_center=True,
feat_compression=16,
return_point_feats=False),
backbone=dict(
type='Asymm3DSpconv',
grid_size=grid_shape,
input_channels=16,
base_channels=32,
norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.1)),
decode_head=dict(
type='Cylinder3DHead',
channels=128,
num_classes=20,
loss_ce=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=False,
class_weight=None,
loss_weight=1.0),
loss_lovasz=dict(type='LovaszLoss', loss_weight=1.0, reduction='none'),
),
train_cfg=None,
test_cfg=dict(mode='whole'),
)
_base_ = [
'../_base_/datasets/semantickitti.py', '../_base_/models/cylinder3d.py',
'../_base_/default_runtime.py'
]
# optimizer
# This schedule is mainly used by models on nuScenes dataset
lr = 0.001
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.01))
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=36, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# learning rate
param_scheduler = [
dict(
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
end=1000),
dict(
type='MultiStepLR',
begin=0,
end=36,
by_epoch=True,
milestones=[30],
gamma=0.1)
]
# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (8 GPUs) x (4 samples per GPU).
# auto_scale_lr = dict(enable=False, base_batch_size=32)
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5))
......@@ -255,8 +255,8 @@ class Seg3DDataset(BaseDataset):
osp.join(
self.data_prefix.get('pts', ''),
info['lidar_points']['lidar_path'])
info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
if 'num_pts_feats' in info['lidar_points']:
info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
info['lidar_path'] = info['lidar_points']['lidar_path']
if self.modality['use_camera']:
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt
from .cylinder3d import Asymm3DSpconv
from .dgcnn import DGCNNBackbone
from .dla import DLANet
from .mink_resnet import MinkResNet
......@@ -13,5 +14,5 @@ from .second import SECOND
__all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
'MultiBackbone', 'DLANet', 'MinkResNet'
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv'
]
# Copyright (c) OpenMMLab. All rights reserved.
r"""Modified from Cylinder3D.
Please refer to `Cylinder3D github page
<https://github.com/xinge008/Cylinder3D>`_ for details
"""
from typing import List
import numpy as np
import torch
from mmcv.ops import SparseConvTensor
from mmengine.model import BaseModule
from mmdet3d.models.layers.sparse_block import (AsymmeDownBlock, AsymmeUpBlock,
AsymmResBlock, DDCMBlock)
from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType
@MODELS.register_module()
class Asymm3DSpconv(BaseModule):
"""Asymmetrical 3D convolution networks.
Args:
grid_size (int): Size of voxel grids.
input_channels (int): Input channels of the block.
base_channels (int): Initial size of feature channels before
feeding into Encoder-Decoder structure. Defaults to 16.
backbone_depth (int): The depth of backbone. The backbone contains
downblocks and upblocks with the number of backbone_depth.
height_pooing (List[bool]): List indicating which downblocks perform
height pooling.
norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
layer. Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01)).
init_cfg (dict, optional): Initialization config.
Defaults to None.
"""
def __init__(self,
grid_size: int,
input_channels: int,
base_channels: int = 16,
backbone_depth: int = 4,
height_pooing: List[bool] = [True, True, False, False],
norm_cfg: ConfigType = dict(
type='BN1d', eps=1e-3, momentum=0.01),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.grid_size = grid_size
self.backbone_depth = backbone_depth
self.down_context = AsymmResBlock(
input_channels, base_channels, indice_key='pre', norm_cfg=norm_cfg)
self.down_block_list = torch.nn.ModuleList()
self.up_block_list = torch.nn.ModuleList()
for i in range(self.backbone_depth):
self.down_block_list.append(
AsymmeDownBlock(
2**i * base_channels,
2**(i + 1) * base_channels,
height_pooling=height_pooing[i],
indice_key='down' + str(i),
norm_cfg=norm_cfg))
if i == self.backbone_depth - 1:
self.up_block_list.append(
AsymmeUpBlock(
2**(i + 1) * base_channels,
2**(i + 1) * base_channels,
up_key='down' + str(i),
indice_key='up' + str(self.backbone_depth - 1 - i),
norm_cfg=norm_cfg))
else:
self.up_block_list.append(
AsymmeUpBlock(
2**(i + 2) * base_channels,
2**(i + 1) * base_channels,
up_key='down' + str(i),
indice_key='up' + str(self.backbone_depth - 1 - i),
norm_cfg=norm_cfg))
self.ddcm = DDCMBlock(
2 * base_channels,
2 * base_channels,
indice_key='ddcm',
norm_cfg=norm_cfg)
def forward(self, voxel_features: torch.Tensor, coors: torch.Tensor,
batch_size: int) -> SparseConvTensor:
"""Forward pass."""
coors = coors.int()
ret = SparseConvTensor(voxel_features, coors, np.array(self.grid_size),
batch_size)
ret = self.down_context(ret)
down_skip_list = []
down_pool = ret
for i in range(self.backbone_depth):
down_pool, down_skip = self.down_block_list[i](down_pool)
down_skip_list.append(down_skip)
up = down_pool
for i in range(self.backbone_depth - 1, -1, -1):
up = self.up_block_list[i](up, down_skip_list[i])
ddcm = self.ddcm(up)
ddcm.features = torch.cat((ddcm.features, up.features), 1)
return ddcm
......@@ -39,7 +39,7 @@ class Cylinder3DHead(Base3DDecodeHead):
conv_seg_kernel_size (int): The kernel size used in conv_seg.
Defaults to 3.
ignore_index (int): The label index to be ignored. When using masked
BCE loss, ignore_index should be set to None. Defaults to 0.
BCE loss, ignore_index should be set to None. Defaults to 19.
init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`],
optional): Initialization config dict. Defaults to None.
"""
......@@ -59,7 +59,7 @@ class Cylinder3DHead(Base3DDecodeHead):
loss_lovasz: ConfigType = dict(
type='LovaszLoss', loss_weight=1.0),
conv_seg_kernel_size: int = 3,
ignore_index: int = 0,
ignore_index: int = 19,
init_cfg: OptMultiConfig = None) -> None:
super(Cylinder3DHead, self).__init__(
channels=channels,
......@@ -116,8 +116,6 @@ class Cylinder3DHead(Base3DDecodeHead):
loss = dict()
loss['loss_ce'] = self.loss_ce(
seg_logit_feat, seg_label, ignore_index=self.ignore_index)
seg_logit_feat = seg_logit_feat.permute(1, 0)[None, :, :,
None] # pseudo BCHW
loss['loss_lovasz'] = self.loss_lovasz(
seg_logit_feat, seg_label, ignore_index=self.ignore_index)
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
from typing import Optional, Tuple, Union
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from torch import nn
from mmdet3d.utils import OptConfigType
from mmdet3d.utils import ConfigType, OptConfigType
from .spconv import IS_SPCONV2_AVAILABLE
if IS_SPCONV2_AVAILABLE:
from spconv.pytorch import SparseConvTensor, SparseModule, SparseSequential
else:
from mmcv.ops import SparseConvTensor, SparseModule, SparseSequential
from mmcv.ops import (SparseConvTensor, SparseModule, SparseSequential,
SparseConv3d, SparseInverseConv3d, SubMConv3d)
from mmengine.model import BaseModule
def replace_feature(out: SparseConvTensor,
......@@ -207,3 +210,374 @@ def make_sparse_convmodule(
layers = SparseSequential(*layers)
return layers
# The following module only supports spconv_v1
class AsymmResBlock(BaseModule):
"""Asymmetrical Residual Block.
Args:
in_channels (int): Input channels of the block.
out_channels (int): Output channels of the block.
norm_cfg (:obj:`ConfigDict` or dict): Config dict for
normalization layer.
act_cfg (:obj:`ConfigDict` or dict): Config dict of activation layers.
Defaults to dict(type='LeakyReLU').
indice_key (str, optional): Name of indice tables. Defaults to None.
"""
def __init__(self,
in_channels: int,
out_channels: int,
norm_cfg: ConfigType,
act_cfg: ConfigType = dict(type='LeakyReLU'),
indice_key: Optional[str] = None):
super().__init__()
self.conv0_0 = SubMConv3d(
in_channels,
out_channels,
kernel_size=(1, 3, 3),
padding=1,
bias=False,
indice_key=indice_key + 'bef')
self.act0_0 = build_activation_layer(act_cfg)
self.bn0_0 = build_norm_layer(norm_cfg, out_channels)[1]
self.conv0_1 = SubMConv3d(
out_channels,
out_channels,
kernel_size=(3, 1, 3),
padding=1,
bias=False,
indice_key=indice_key + 'bef')
self.act0_1 = build_activation_layer(act_cfg)
self.bn0_1 = build_norm_layer(norm_cfg, out_channels)[1]
self.conv1_0 = SubMConv3d(
in_channels,
out_channels,
kernel_size=(3, 1, 3),
padding=1,
bias=False,
indice_key=indice_key + 'bef')
self.act1_0 = build_activation_layer(act_cfg)
self.bn1_0 = build_norm_layer(norm_cfg, out_channels)[1]
self.conv1_1 = SubMConv3d(
out_channels,
out_channels,
kernel_size=(1, 3, 3),
padding=1,
bias=False,
indice_key=indice_key + 'bef')
self.act1_1 = build_activation_layer(act_cfg)
self.bn1_1 = build_norm_layer(norm_cfg, out_channels)[1]
def forward(self, x: SparseConvTensor) -> SparseConvTensor:
"""Forward pass."""
shortcut = self.conv0_0(x)
shortcut.features = self.act0_0(shortcut.features)
shortcut.features = self.bn0_0(shortcut.features)
shortcut = self.conv0_1(shortcut)
shortcut.features = self.act0_1(shortcut.features)
shortcut.features = self.bn0_1(shortcut.features)
res = self.conv1_0(x)
res.features = self.act1_0(res.features)
res.features = self.bn1_0(res.features)
res = self.conv1_1(res)
res.features = self.act1_1(res.features)
res.features = self.bn1_1(res.features)
res.features = res.features + shortcut.features
return res
class AsymmeDownBlock(BaseModule):
"""Asymmetrical DownSample Block.
Args:
in_channels (int): Input channels of the block.
out_channels (int): Output channels of the block.
norm_cfg (:obj:`ConfigDict` or dict): Config dict for
normalization layer.
act_cfg (:obj:`ConfigDict` or dict): Config dict of activation layers.
Defaults to dict(type='LeakyReLU').
pooling (bool): Whether pooling features at the end of
block. Defaults: True.
height_pooling (bool): Whether pooling features at
the height dimension. Defaults: False.
indice_key (str, optional): Name of indice tables. Defaults to None.
"""
def __init__(self,
in_channels: int,
out_channels: int,
norm_cfg: ConfigType,
act_cfg: ConfigType = dict(type='LeakyReLU'),
pooling: bool = True,
height_pooling: bool = False,
indice_key: Optional[str] = None):
super().__init__()
self.pooling = pooling
self.conv0_0 = SubMConv3d(
in_channels,
out_channels,
kernel_size=(3, 1, 3),
padding=1,
bias=False,
indice_key=indice_key + 'bef')
self.act0_0 = build_activation_layer(act_cfg)
self.bn0_0 = build_norm_layer(norm_cfg, out_channels)[1]
self.conv0_1 = SubMConv3d(
out_channels,
out_channels,
kernel_size=(1, 3, 3),
padding=1,
bias=False,
indice_key=indice_key + 'bef')
self.act0_1 = build_activation_layer(act_cfg)
self.bn0_1 = build_norm_layer(norm_cfg, out_channels)[1]
self.conv1_0 = SubMConv3d(
in_channels,
out_channels,
kernel_size=(1, 3, 3),
padding=1,
bias=False,
indice_key=indice_key + 'bef')
self.act1_0 = build_activation_layer(act_cfg)
self.bn1_0 = build_norm_layer(norm_cfg, out_channels)[1]
self.conv1_1 = SubMConv3d(
out_channels,
out_channels,
kernel_size=(3, 1, 3),
padding=1,
bias=False,
indice_key=indice_key + 'bef')
self.act1_1 = build_activation_layer(act_cfg)
self.bn1_1 = build_norm_layer(norm_cfg, out_channels)[1]
if pooling:
if height_pooling:
self.pool = SparseConv3d(
out_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
indice_key=indice_key,
bias=False)
else:
self.pool = SparseConv3d(
out_channels,
out_channels,
kernel_size=3,
stride=(2, 2, 1),
padding=1,
indice_key=indice_key,
bias=False)
def forward(self, x: SparseConvTensor) -> SparseConvTensor:
"""Forward pass."""
shortcut = self.conv0_0(x)
shortcut.features = self.act0_0(shortcut.features)
shortcut.features = self.bn0_0(shortcut.features)
shortcut = self.conv0_1(shortcut)
shortcut.features = self.act0_1(shortcut.features)
shortcut.features = self.bn0_1(shortcut.features)
res = self.conv1_0(x)
res.features = self.act1_0(res.features)
res.features = self.bn1_0(res.features)
res = self.conv1_1(res)
res.features = self.act1_1(res.features)
res.features = self.bn1_1(res.features)
res.features = res.features + shortcut.features
if self.pooling:
pooled_res = self.pool(res)
return pooled_res, res
else:
return res
class AsymmeUpBlock(BaseModule):
"""Asymmetrical UpSample Block.
Args:
in_channels (int): Input channels of the block.
out_channels (int): Output channels of the block.
norm_cfg (:obj:`ConfigDict` or dict): Config dict for
normalization layer.
act_cfg (:obj:`ConfigDict` or dict): Config dict of activation layers.
Defaults to dict(type='LeakyReLU').
indice_key (str, optional): Name of indice tables. Defaults to None.
up_key (str, optional): Name of indice tables used in
SparseInverseConv3d. Defaults to None.
"""
def __init__(self,
in_channels: int,
out_channels: int,
norm_cfg: ConfigType,
act_cfg: ConfigType = dict(type='LeakyReLU'),
indice_key: Optional[str] = None,
up_key: Optional[str] = None):
super().__init__()
self.trans_conv = SubMConv3d(
in_channels,
out_channels,
kernel_size=(3, 3, 3),
padding=1,
bias=False,
indice_key=indice_key + 'new_up')
self.trans_act = build_activation_layer(act_cfg)
self.trans_bn = build_norm_layer(norm_cfg, out_channels)[1]
self.conv1 = SubMConv3d(
out_channels,
out_channels,
kernel_size=(1, 3, 3),
padding=1,
bias=False,
indice_key=indice_key)
self.act1 = build_activation_layer(act_cfg)
self.bn1 = build_norm_layer(norm_cfg, out_channels)[1]
self.conv2 = SubMConv3d(
out_channels,
out_channels,
kernel_size=(3, 1, 3),
padding=1,
bias=False,
indice_key=indice_key)
self.act2 = build_activation_layer(act_cfg)
self.bn2 = build_norm_layer(norm_cfg, out_channels)[1]
self.conv3 = SubMConv3d(
out_channels,
out_channels,
kernel_size=(3, 3, 3),
padding=1,
bias=False,
indice_key=indice_key)
self.act3 = build_activation_layer(act_cfg)
self.bn3 = build_norm_layer(norm_cfg, out_channels)[1]
self.up_subm = SparseInverseConv3d(
out_channels,
out_channels,
kernel_size=3,
indice_key=up_key,
bias=False)
def forward(self, x: SparseConvTensor,
skip: SparseConvTensor) -> SparseConvTensor:
"""Forward pass."""
x_trans = self.trans_conv(x)
x_trans.features = self.trans_act(x_trans.features)
x_trans.features = self.trans_bn(x_trans.features)
# upsample
up = self.up_subm(x_trans)
up.features = up.features + skip.features
up = self.conv1(up)
up.features = self.act1(up.features)
up.features = self.bn1(up.features)
up = self.conv2(up)
up.features = self.act2(up.features)
up.features = self.bn2(up.features)
up = self.conv3(up)
up.features = self.act3(up.features)
up.features = self.bn3(up.features)
return up
class DDCMBlock(BaseModule):
"""Dimension-Decomposition based Context Modeling.
Args:
in_channels (int): Input channels of the block.
out_channels (int): Output channels of the block.
norm_cfg (:obj:`ConfigDict` or dict): Config dict for
normalization layer.
act_cfg (:obj:`ConfigDict` or dict): Config dict of activation layers.
Defaults to dict(type='Sigmoid').
indice_key (str, optional): Name of indice tables. Defaults to None.
"""
def __init__(self,
in_channels: int,
out_channels: int,
norm_cfg: ConfigType,
act_cfg: ConfigType = dict(type='Sigmoid'),
indice_key: Optional[str] = None):
super().__init__()
self.conv1 = SubMConv3d(
in_channels,
out_channels,
kernel_size=(3, 1, 1),
padding=1,
bias=False,
indice_key=indice_key)
self.bn1 = build_norm_layer(norm_cfg, out_channels)[1]
self.act1 = build_activation_layer(act_cfg)
self.conv2 = SubMConv3d(
in_channels,
out_channels,
kernel_size=(1, 3, 1),
padding=1,
bias=False,
indice_key=indice_key)
self.bn2 = build_norm_layer(norm_cfg, out_channels)[1]
self.act2 = build_activation_layer(act_cfg)
self.conv3 = SubMConv3d(
in_channels,
out_channels,
kernel_size=(1, 1, 3),
padding=1,
bias=False,
indice_key=indice_key)
self.bn3 = build_norm_layer(norm_cfg, out_channels)[1]
self.act3 = build_activation_layer(act_cfg)
def forward(self, x: SparseConvTensor) -> SparseConvTensor:
"""Forward pass."""
shortcut = self.conv1(x)
shortcut.features = self.bn1(shortcut.features)
shortcut.features = self.act1(shortcut.features)
shortcut2 = self.conv2(x)
shortcut2.features = self.bn2(shortcut2.features)
shortcut2.features = self.act2(shortcut2.features)
shortcut3 = self.conv3(x)
shortcut3.features = self.bn3(shortcut3.features)
shortcut3.features = self.act3(shortcut3.features)
shortcut.features = shortcut.features + \
shortcut2.features + shortcut3.features
shortcut.features = shortcut.features * x.features
return shortcut
# Copyright (c) OpenMMLab. All rights reserved.
from .base import Base3DSegmentor
from .cylinder3d import Cylinder3D
from .encoder_decoder import EncoderDecoder3D
__all__ = ['Base3DSegmentor', 'EncoderDecoder3D']
__all__ = ['Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
from torch import Tensor
from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
from ...structures.det3d_data_sample import SampleList
from .encoder_decoder import EncoderDecoder3D
@MODELS.register_module()
class Cylinder3D(EncoderDecoder3D):
"""`Cylindrical and Asymmetrical 3D Convolution Networks for LiDAR
Segmentation.
<https://arxiv.org/abs/2011.10033>`_.
Args:
voxel_encoder (dict or :obj:`ConfigDict`): The config for the
points2voxel encoder of segmentor.
backbone (dict or :obj:`ConfigDict`): The config for the backnone of
segmentor.
decode_head (dict or :obj:`ConfigDict`): The config for the decode
head of segmentor.
neck (dict or :obj:`ConfigDict`, optional): The config for the neck of
segmentor. Defaults to None.
auxiliary_head (dict or :obj:`ConfigDict` or List[dict or
:obj:`ConfigDict`], optional): The config for the auxiliary head of
segmentor. Defaults to None.
loss_regularization (dict or :obj:`ConfigDict` or List[dict or
:obj:`ConfigDict`], optional): The config for the regularization
loass. Defaults to None.
train_cfg (dict or :obj:`ConfigDict`, optional): The config for
training. Defaults to None.
test_cfg (dict or :obj:`ConfigDict`, optional): The config for testing.
Defaults to None.
data_preprocessor (dict or :obj:`ConfigDict`, optional): The
pre-process config of :class:`BaseDataPreprocessor`.
Defaults to None.
init_cfg (dict or :obj:`ConfigDict` or List[dict or :obj:`ConfigDict`],
optional): The weight initialized config for :class:`BaseModule`.
Defaults to None.
"""
def __init__(self,
voxel_encoder: ConfigType,
backbone: ConfigType,
decode_head: ConfigType,
neck: OptConfigType = None,
auxiliary_head: OptConfigType = None,
loss_regularization: OptConfigType = None,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None) -> None:
super(Cylinder3D, self).__init__(
backbone=backbone,
decode_head=decode_head,
neck=neck,
auxiliary_head=auxiliary_head,
loss_regularization=loss_regularization,
train_cfg=train_cfg,
test_cfg=test_cfg,
data_preprocessor=data_preprocessor,
init_cfg=init_cfg)
self.voxel_encoder = MODELS.build(voxel_encoder)
def extract_feat(self, batch_inputs: dict) -> Tensor:
"""Extract features from points."""
encoded_feats = self.voxel_encoder(batch_inputs['voxels']['voxels'],
batch_inputs['voxels']['coors'])
batch_inputs['voxels']['voxel_coors'] = encoded_feats[1]
x = self.backbone(encoded_feats[0], encoded_feats[1],
len(batch_inputs['points']))
if self.with_neck:
x = self.neck(x)
return x
def loss(self, batch_inputs_dict: dict,
batch_data_samples: SampleList) -> Dict[str, Tensor]:
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs_dict (dict): Input sample dict which
includes 'points' and 'imgs' keys.
- points (List[Tensor]): Point cloud of each sample.
- imgs (Tensor, optional): Image tensor has shape (B, C, H, W).
batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
samples. It usually includes information such as `metainfo` and
`gt_pts_seg`.
Returns:
Dict[str, Tensor]: A dictionary of loss components.
"""
# extract features using backbone
x = self.extract_feat(batch_inputs_dict)
losses = dict()
loss_decode = self._decode_head_forward_train(x, batch_data_samples)
losses.update(loss_decode)
return losses
def predict(self,
batch_inputs_dict: dict,
batch_data_samples: SampleList,
rescale: bool = True) -> SampleList:
"""Simple test with single scene.
Args:
batch_inputs_dict (dict): Input sample dict which includes 'points'
and 'imgs' keys.
- points (List[Tensor]): Point cloud of each sample.
- imgs (Tensor, optional): Image tensor has shape (B, C, H, W).
batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
samples. It usually includes information such as `metainfo` and
`gt_pts_seg`.
rescale (bool): Whether transform to original number of points.
Will be used for voxelization based segmentors.
Defaults to True.
Returns:
List[:obj:`Det3DDataSample`]: Segmentation results of the input
points. Each Det3DDataSample usually contains:
- ``pred_pts_seg`` (PixelData): Prediction of 3D semantic
segmentation.
"""
# 3D segmentation requires per-point prediction, so it's impossible
# to use down-sampling to get a batch of scenes with same num_points
# therefore, we only support testing one scene every time
x = self.extract_feat(batch_inputs_dict)
seg_pred_list = self.decode_head.predict(x, batch_inputs_dict,
batch_data_samples)
for i in range(len(seg_pred_list)):
seg_pred_list[i] = seg_pred_list[i].argmax(1).cpu()
return self.postprocess_result(seg_pred_list, batch_data_samples)
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmdet3d.registry import MODELS
def test_cylinder3d():
if not torch.cuda.is_available():
pytest.skip()
cfg = dict(
type='Asymm3DSpconv',
grid_size=[48, 32, 4],
input_channels=16,
base_channels=32,
norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.1))
self = MODELS.build(cfg)
self.cuda()
batch_size = 1
coorx = torch.randint(0, 48, (50, 1))
coory = torch.randint(0, 36, (50, 1))
coorz = torch.randint(0, 4, (50, 1))
coorbatch = torch.zeros(50, 1)
coors = torch.cat([coorbatch, coorx, coory, coorz], dim=1).cuda()
voxel_features = torch.rand(50, 16).cuda()
# test forward
feature = self(voxel_features, coors, batch_size)
assert feature.features.shape == (50, 128)
assert feature.indices.data.shape == (50, 4)
import unittest
import torch
from mmengine import DefaultScope
from mmdet3d.registry import MODELS
from mmdet3d.testing import (create_detector_inputs, get_detector_cfg,
setup_seed)
class TestCylinder3D(unittest.TestCase):
def test_cylinder3d(self):
import mmdet3d.models
assert hasattr(mmdet3d.models, 'Cylinder3D')
DefaultScope.get_instance('test_cylinder3d', scope_name='mmdet3d')
setup_seed(0)
cylinder3d_cfg = get_detector_cfg(
'cylinder3d/cylinder3d_4xb2_3x_semantickitti.py')
cylinder3d_cfg.decode_head['ignore_index'] = 1
model = MODELS.build(cylinder3d_cfg)
num_gt_instance = 3
packed_inputs = create_detector_inputs(
num_gt_instance=num_gt_instance,
num_classes=1,
with_pts_semantic_mask=True)
if torch.cuda.is_available():
model = model.cuda()
# test simple_test
with torch.no_grad():
data = model.data_preprocessor(packed_inputs, True)
torch.cuda.empty_cache()
results = model.forward(**data, mode='predict')
self.assertEqual(len(results), 1)
self.assertIn('pts_semantic_mask', results[0].pred_pts_seg)
losses = model.forward(**data, mode='loss')
self.assertGreater(losses['decode.loss_ce'], 0)
self.assertGreater(losses['decode.loss_lovasz'], 0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment