Commit aa544d4a authored by VVsssssk's avatar VVsssssk Committed by ZwwWayne

[Features] Support PV_RCNN modules (#1957)

* add pvrcnn module code

* add voxelsa

* fix

* fix comments

* fix comments

* fix comments

* add stack sa

* fix

* fix comments

* fix comments

* fix

* add ut

* fix comments
parent 1fd71531
_base_ = [
'../_base_/datasets/kitti-3d-3class.py',
'../_base_/schedules/cyclic-40e.py', '../_base_/default_runtime.py'
]
voxel_size = [0.05, 0.05, 0.1]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
data_root = 'data/kitti/'
class_names = ['Pedestrian', 'Cyclist', 'Car']
metainfo = dict(CLASSES=class_names)
db_sampler = dict(
data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0,
prepare=dict(
filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
classes=class_names,
sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10),
points_loader=dict(
type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4))
train_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler, use_ground_plane=True),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05]),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
model = dict(
type='PointVoxelRCNN',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
voxel_layer=dict(
max_num_points=5, # max_points_per_voxel
point_cloud_range=point_cloud_range,
voxel_size=voxel_size,
max_voxels=(16000, 40000))),
voxel_encoder=dict(type='HardSimpleVFE'),
middle_encoder=dict(
type='SparseEncoder',
in_channels=4,
sparse_shape=[41, 1600, 1408],
order=('conv', 'norm', 'act'),
encoder_paddings=((0, 0, 0), ((1, 1, 1), 0, 0), ((1, 1, 1), 0, 0),
((0, 1, 1), 0, 0)),
return_middle_feats=True),
points_encoder=dict(
type='VoxelSetAbstraction',
num_keypoints=2048,
fused_out_channel=128,
voxel_size=voxel_size,
point_cloud_range=point_cloud_range,
voxel_sa_cfgs_list=[
dict(
type='StackedSAModuleMSG',
in_channels=16,
scale_factor=1,
radius=(0.4, 0.8),
sample_nums=(16, 16),
mlp_channels=((16, 16), (16, 16)),
use_xyz=True),
dict(
type='StackedSAModuleMSG',
in_channels=32,
scale_factor=2,
radius=(0.8, 1.2),
sample_nums=(16, 32),
mlp_channels=((32, 32), (32, 32)),
use_xyz=True),
dict(
type='StackedSAModuleMSG',
in_channels=64,
scale_factor=4,
radius=(1.2, 2.4),
sample_nums=(16, 32),
mlp_channels=((64, 64), (64, 64)),
use_xyz=True),
dict(
type='StackedSAModuleMSG',
in_channels=64,
scale_factor=8,
radius=(2.4, 4.8),
sample_nums=(16, 32),
mlp_channels=((64, 64), (64, 64)),
use_xyz=True)
],
rawpoints_sa_cfgs=dict(
type='StackedSAModuleMSG',
in_channels=1,
radius=(0.4, 0.8),
sample_nums=(16, 16),
mlp_channels=((16, 16), (16, 16)),
use_xyz=True),
bev_feat_channel=256,
bev_scale_factor=8),
backbone=dict(
type='SECOND',
in_channels=256,
layer_nums=[5, 5],
layer_strides=[1, 2],
out_channels=[128, 256]),
neck=dict(
type='SECONDFPN',
in_channels=[128, 256],
upsample_strides=[1, 2],
out_channels=[256, 256]),
rpn_head=dict(
type='PartA2RPNHead',
num_classes=3,
in_channels=512,
feat_channels=512,
use_direction_classifier=True,
dir_offset=0.78539,
anchor_generator=dict(
type='Anchor3DRangeGenerator',
ranges=[[0, -40.0, -0.6, 70.4, 40.0, -0.6],
[0, -40.0, -0.6, 70.4, 40.0, -0.6],
[0, -40.0, -1.78, 70.4, 40.0, -1.78]],
sizes=[[0.8, 0.6, 1.73], [1.76, 0.6, 1.73], [3.9, 1.6, 1.56]],
rotations=[0, 1.57],
reshape_out=False),
diff_rad_by_sin=True,
assigner_per_size=True,
assign_per_class=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict(
type='mmdet.FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(
type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict(
type='mmdet.CrossEntropyLoss', use_sigmoid=False,
loss_weight=0.2)),
roi_head=dict(
type='PVRCNNRoiHead',
num_classes=3,
semantic_head=dict(
type='ForegroundSegmentationHead',
in_channels=640,
extra_width=0.1,
loss_seg=dict(
type='mmdet.FocalLoss',
use_sigmoid=True,
reduction='sum',
gamma=2.0,
alpha=0.25,
activated=True,
loss_weight=1.0)),
bbox_roi_extractor=dict(
type='Batch3DRoIGridExtractor',
grid_size=6,
roi_layer=dict(
type='StackedSAModuleMSG',
in_channels=128,
radius=(0.8, 1.6),
sample_nums=(16, 16),
mlp_channels=((64, 64), (64, 64)),
use_xyz=True,
pool_mod='max'),
),
bbox_head=dict(
type='PVRCNNBBoxHead',
in_channels=128,
grid_size=6,
num_classes=3,
class_agnostic=True,
shared_fc_channels=(256, 256),
reg_channels=(256, 256),
cls_channels=(256, 256),
dropout_ratio=0.3,
with_corner_loss=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_bbox=dict(
type='mmdet.SmoothL1Loss',
beta=1.0 / 9.0,
reduction='sum',
loss_weight=1.0),
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=[
dict( # for Pedestrian
type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.5,
neg_iou_thr=0.35,
min_pos_iou=0.35,
ignore_iof_thr=-1),
dict( # for Cyclist
type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.5,
neg_iou_thr=0.35,
min_pos_iou=0.35,
ignore_iof_thr=-1),
dict( # for Car
type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.6,
neg_iou_thr=0.45,
min_pos_iou=0.45,
ignore_iof_thr=-1)
],
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=9000,
nms_post=512,
max_num=512,
nms_thr=0.8,
score_thr=0,
use_rotate_nms=True),
rcnn=dict(
assigner=[
dict( # for Pedestrian
type='Max3DIoUAssigner',
iou_calculator=dict(
type='BboxOverlaps3D', coordinate='lidar'),
pos_iou_thr=0.55,
neg_iou_thr=0.55,
min_pos_iou=0.55,
ignore_iof_thr=-1),
dict( # for Cyclist
type='Max3DIoUAssigner',
iou_calculator=dict(
type='BboxOverlaps3D', coordinate='lidar'),
pos_iou_thr=0.55,
neg_iou_thr=0.55,
min_pos_iou=0.55,
ignore_iof_thr=-1),
dict( # for Car
type='Max3DIoUAssigner',
iou_calculator=dict(
type='BboxOverlaps3D', coordinate='lidar'),
pos_iou_thr=0.55,
neg_iou_thr=0.55,
min_pos_iou=0.55,
ignore_iof_thr=-1)
],
sampler=dict(
type='IoUNegPiecewiseSampler',
num=128,
pos_fraction=0.5,
neg_piece_fractions=[0.8, 0.2],
neg_iou_piece_thrs=[0.55, 0.1],
neg_pos_ub=-1,
add_gt_as_proposals=False,
return_iou=True),
cls_pos_thr=0.75,
cls_neg_thr=0.25)),
test_cfg=dict(
rpn=dict(
nms_pre=1024,
nms_post=100,
max_num=100,
nms_thr=0.7,
score_thr=0,
use_rotate_nms=True),
rcnn=dict(
use_rotate_nms=True,
use_raw_score=True,
nms_thr=0.1,
score_thr=0.1)))
train_dataloader = dict(
batch_size=2,
num_workers=2,
dataset=dict(dataset=dict(pipeline=train_pipeline, metainfo=metainfo)))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
lr = 0.001
optim_wrapper = dict(optimizer=dict(lr=lr))
param_scheduler = [
# learning rate scheduler
# During the first 15 epochs, learning rate increases from lr to lr * 10
# during the next 25 epochs, learning rate decreases from lr * 10 to
# lr * 1e-4
dict(
type='CosineAnnealingLR',
T_max=15,
eta_min=lr * 10,
begin=0,
end=15,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=25,
eta_min=lr * 1e-4,
begin=15,
end=40,
by_epoch=True,
convert_to_iter_based=True),
# momentum scheduler
# During the first 15 epochs, momentum is annealed towards 0.85 / 0.95
# during the next 25 epochs, momentum is annealed from 0.85 / 0.95 to 1
dict(
type='CosineAnnealingMomentum',
T_max=15,
eta_min=0.85 / 0.95,
begin=0,
end=15,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingMomentum',
T_max=25,
eta_min=1,
begin=15,
end=40,
by_epoch=True,
convert_to_iter_based=True)
]
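A minimal usage sketch for the config above, loading it through the init_model API touched in the hunk just below. The config and checkpoint paths are assumptions for illustration only; this commit does not pin the final file names.
# Hedged sketch: the paths below are hypothetical and a CUDA device plus a
# full mmdet3d install are assumed.
from mmdet3d.apis import init_model

config_file = 'configs/pv_rcnn/pv_rcnn_8xb2-80e_kitti-3d-3class.py'  # assumed path
checkpoint_file = 'checkpoints/pv_rcnn_kitti-3class.pth'  # assumed path
model = init_model(config_file, checkpoint_file, device='cuda:0')
print(type(model).__name__)  # expected: PointVoxelRCNN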
@@ -67,7 +67,6 @@ def init_model(config: Union[str, Path, Config],
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
# save the dataset_meta in the model for convenience
if 'dataset_meta' in checkpoint.get('meta', {}):
...
@@ -14,6 +14,7 @@ from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN
from .mvx_two_stage import MVXTwoStageDetector
from .parta2 import PartA2
from .point_rcnn import PointRCNN
from .pv_rcnn import PointVoxelRCNN
from .sassd import SASSD
from .single_stage_mono3d import SingleStageMono3DDetector
from .smoke_mono3d import SMOKEMono3D
@@ -26,5 +27,6 @@ __all__ = [
'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet',
'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector',
'FCOSMono3D', 'ImVoxelNet', 'GroupFree3DNet', 'PointRCNN', 'SMOKEMono3D',
'SASSD', 'MinkSingleStage3DDetector', 'MultiViewDfM', 'DfM',
'PointVoxelRCNN'
]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import InstanceList
from .two_stage import TwoStage3DDetector
@MODELS.register_module()
class PointVoxelRCNN(TwoStage3DDetector):
r"""PointVoxelRCNN detector.
Please refer to the `PointVoxelRCNN <https://arxiv.org/abs/1912.13192>`_.
Args:
voxel_encoder (dict): Point voxelization encoder layer.
middle_encoder (dict): Middle encoder layer
of the point cloud modality.
backbone (dict): Backbone of extracting points features.
neck (dict, optional): Neck of extracting points features.
Defaults to None.
rpn_head (dict, optional): Config of RPN head. Defaults to None.
points_encoder (dict, optional): Points encoder to extract point-wise
features. Defaults to None.
roi_head (dict, optional): Config of ROI head. Defaults to None.
train_cfg (dict, optional): Train config of model.
Defaults to None.
test_cfg (dict, optional): Test config of model.
Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`Det3DDataPreprocessor`. Defaults to None.
"""
def __init__(self,
voxel_encoder: dict,
middle_encoder: dict,
backbone: dict,
neck: Optional[dict] = None,
rpn_head: Optional[dict] = None,
points_encoder: Optional[dict] = None,
roi_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
data_preprocessor: Optional[dict] = None) -> None:
super().__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg,
data_preprocessor=data_preprocessor)
self.voxel_encoder = MODELS.build(voxel_encoder)
self.middle_encoder = MODELS.build(middle_encoder)
self.points_encoder = MODELS.build(points_encoder)
def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
input samples. Each Det3DDataSample usually contains
'pred_instances_3d', and the ``pred_instances_3d`` usually
contains the following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instance, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (Tensor): Contains a tensor with shape
(num_instances, C) where C >=7.
"""
feats_dict = self.extract_feat(batch_inputs_dict)
if self.with_rpn:
rpn_results_list = self.rpn_head.predict(feats_dict,
batch_data_samples)
else:
rpn_results_list = [
data_sample.proposals for data_sample in batch_data_samples
]
# extract point features by points_encoder
points_feats_dict = self.extract_points_feat(batch_inputs_dict,
feats_dict,
rpn_results_list)
results_list_3d = self.roi_head.predict(points_feats_dict,
rpn_results_list,
batch_data_samples)
# convert to Det3DDataSample
results_list = self.add_pred_to_datasample(batch_data_samples,
results_list_3d)
return results_list
def extract_feat(self, batch_inputs_dict: dict) -> dict:
"""Extract features from the input voxels.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
Returns:
dict: We typically obtain a dict of features from the backbone and
neck. It includes:
- spatial_feats (torch.Tensor): Spatial features from the middle
encoder.
- multi_scale_3d_feats (list[torch.Tensor]): Multi-scale middle
features from the middle encoder.
- neck_feats (torch.Tensor): Features from the neck.
"""
feats_dict = dict()
voxel_dict = batch_inputs_dict['voxels']
voxel_features = self.voxel_encoder(voxel_dict['voxels'],
voxel_dict['num_points'],
voxel_dict['coors'])
batch_size = voxel_dict['coors'][-1, 0].item() + 1
feats_dict['spatial_feats'], feats_dict[
'multi_scale_3d_feats'] = self.middle_encoder(
voxel_features, voxel_dict['coors'], batch_size)
x = self.backbone(feats_dict['spatial_feats'])
if self.with_neck:
neck_feats = self.neck(x)
feats_dict['neck_feats'] = neck_feats
return feats_dict
def extract_points_feat(self, batch_inputs_dict: dict, feats_dict: dict,
rpn_results_list: InstanceList) -> dict:
"""Extract point-wise features from the raw points and voxel features.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
feats_dict (dict): Contains features from the first stage.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
Returns:
dict: Contains the point-wise features, including:
- keypoints (torch.Tensor): Sampled key points.
- keypoint_features (torch.Tensor): Key point features gathered
from the multiple inputs.
- fusion_keypoint_features (torch.Tensor): Key point features
fused by point_feature_fusion_layer.
"""
return self.points_encoder(batch_inputs_dict, feats_dict,
rpn_results_list)
def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs):
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
dict: A dictionary of loss components.
"""
feats_dict = self.extract_feat(batch_inputs_dict)
losses = dict()
# RPN forward and loss
if self.with_rpn:
proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn)
rpn_data_samples = copy.deepcopy(batch_data_samples)
rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(
feats_dict,
rpn_data_samples,
proposal_cfg=proposal_cfg,
**kwargs)
# avoid get same name with roi_head loss
keys = rpn_losses.keys()
for key in keys:
if 'loss' in key and 'rpn' not in key:
rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
losses.update(rpn_losses)
else:
# TODO: Not supported currently, should add a check as in Fast R-CNN
assert batch_data_samples[0].get('proposals', None) is not None
# use pre-defined proposals in InstanceData for the second stage
# to extract ROI features.
rpn_results_list = [
data_sample.proposals for data_sample in batch_data_samples
]
points_feats_dict = self.extract_points_feat(batch_inputs_dict,
feats_dict,
rpn_results_list)
roi_losses = self.roi_head.loss(points_feats_dict, rpn_results_list,
batch_data_samples)
losses.update(roi_losses)
return losses
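Since PointVoxelRCNN is registered in MODELS above, the detector can also be assembled directly from the registry. A hedged sketch, assuming the config path from this commit (hypothetical name) and a full mmcv/mmdet3d installation:
# Sketch only: cfg.model is the `model` dict from the PV-RCNN config above;
# the config path is an assumption.
from mmengine.config import Config
from mmdet3d.registry import MODELS
from mmdet3d.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile('configs/pv_rcnn/pv_rcnn_8xb2-80e_kitti-3d-3class.py')
detector = MODELS.build(cfg.model)
assert detector.with_rpn  # two-stage: RPN head plus PVRCNNRoiHead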
@@ -4,9 +4,10 @@ from .paconv_sa_module import (PAConvCUDASAModule, PAConvCUDASAModuleMSG,
PAConvSAModule, PAConvSAModuleMSG)
from .point_fp_module import PointFPModule
from .point_sa_module import PointSAModule, PointSAModuleMSG
from .stack_point_sa_module import StackedSAModuleMSG
__all__ = [
'build_sa_module', 'PointSAModuleMSG', 'PointSAModule', 'PointFPModule',
'PAConvSAModule', 'PAConvSAModuleMSG', 'PAConvCUDASAModule',
'PAConvCUDASAModuleMSG', 'StackedSAModuleMSG'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import ball_query, grouping_operation
from mmengine.model import BaseModule
from torch import Tensor
from mmdet3d.registry import MODELS
class StackQueryAndGroup(BaseModule):
"""Find nearby points in spherical space.
Args:
radius (float): Radius of the ball query.
sample_nums (int): Number of samples in each ball query.
use_xyz (bool): Whether to use xyz. Default: True.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
radius: float,
sample_nums: int,
use_xyz: bool = True,
init_cfg: dict = None):
super().__init__(init_cfg=init_cfg)
self.radius, self.sample_nums, self.use_xyz = \
radius, sample_nums, use_xyz
def forward(self,
xyz: torch.Tensor,
xyz_batch_cnt: torch.Tensor,
new_xyz: torch.Tensor,
new_xyz_batch_cnt: torch.Tensor,
features: torch.Tensor = None) -> Tuple[Tensor, Tensor]:
"""Forward.
Args:
xyz (Tensor): XYZ coordinates of the input points, with shape
(N1 + N2 ..., 3).
xyz_batch_cnt (Tensor): Number of stacked input points in each
sample, e.g. (N1, N2, ...).
new_xyz (Tensor): Coordinates of the output points, with shape
(M1 + M2 ..., 3).
new_xyz_batch_cnt (Tensor): Number of stacked output points in
each sample, e.g. (M1, M2, ...).
features (Tensor, optional): Features of each point with shape
(N1 + N2 ..., C). C is the number of feature channels. Default: None.
"""
assert xyz.shape[0] == xyz_batch_cnt.sum(), \
f'xyz: {xyz.shape}, xyz_batch_cnt: {xyz_batch_cnt}'
assert new_xyz.shape[0] == new_xyz_batch_cnt.sum(), \
f'new_xyz: {new_xyz.shape}, ' \
f'new_xyz_batch_cnt: {new_xyz_batch_cnt}'
# idx: (M1 + M2 ..., nsample), empty_ball_mask: (M1 + M2 ...)
idx, empty_ball_mask = ball_query(0, self.radius, self.sample_nums,
xyz, new_xyz, xyz_batch_cnt,
new_xyz_batch_cnt)
grouped_xyz = grouping_operation(
xyz, idx, xyz_batch_cnt,
new_xyz_batch_cnt) # (M1 + M2, 3, nsample)
grouped_xyz -= new_xyz.unsqueeze(-1)
grouped_xyz[empty_ball_mask] = 0
if features is not None:
grouped_features = grouping_operation(
features, idx, xyz_batch_cnt,
new_xyz_batch_cnt) # (M1 + M2, C, nsample)
grouped_features[empty_ball_mask] = 0
if self.use_xyz:
new_features = torch.cat(
[grouped_xyz, grouped_features],
dim=1) # (M1 + M2 ..., C + 3, nsample)
else:
new_features = grouped_features
else:
assert self.use_xyz, \
'use_xyz must be True when no point features are provided!'
new_features = grouped_xyz
return new_features, idx
@MODELS.register_module()
class StackedSAModuleMSG(BaseModule):
"""Stack point set abstraction module.
Args:
in_channels (int): Input channels.
radius (list[float]): List of ball query radii, one per scale.
sample_nums (list[int]): Number of samples in each ball query.
mlp_channels (list[list[int]]): Specify mlp channels of the
pointnet before the global pooling for each scale to encode
point features.
use_xyz (bool): Whether to use xyz. Default: True.
pool_mod (str): Type of pooling method, 'max' or 'avg'.
Default: 'max'.
norm_cfg (dict): Type of normalization method. Defaults to
dict(type='BN2d', eps=1e-5, momentum=0.01).
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
in_channels: int,
radius: List[float],
sample_nums: List[int],
mlp_channels: List[List[int]],
use_xyz: bool = True,
pool_mod='max',
norm_cfg: dict = dict(type='BN2d', eps=1e-5, momentum=0.01),
init_cfg: dict = None,
**kwargs) -> None:
super(StackedSAModuleMSG, self).__init__(init_cfg=init_cfg)
assert len(radius) == len(sample_nums) == len(mlp_channels)
self.groupers = nn.ModuleList()
self.mlps = nn.ModuleList()
for i in range(len(radius)):
cin = in_channels
if use_xyz:
cin += 3
cur_radius = radius[i]
nsample = sample_nums[i]
mlp_spec = mlp_channels[i]
self.groupers.append(
StackQueryAndGroup(cur_radius, nsample, use_xyz=use_xyz))
mlp = nn.Sequential()
for i in range(len(mlp_spec)):
cout = mlp_spec[i]
mlp.add_module(
f'layer{i}',
ConvModule(
cin,
cout,
kernel_size=(1, 1),
stride=(1, 1),
conv_cfg=dict(type='Conv2d'),
norm_cfg=norm_cfg,
bias=False))
cin = cout
self.mlps.append(mlp)
self.pool_mod = pool_mod
def forward(self,
xyz: Tensor,
xyz_batch_cnt: Tensor,
new_xyz: Tensor,
new_xyz_batch_cnt: Tensor,
features: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Forward.
Args:
xyz (Tensor): XYZ coordinates of the input points, with shape
(N1 + N2 ..., 3).
xyz_batch_cnt (Tensor): Number of stacked input points in each
sample, e.g. (N1, N2, ...).
new_xyz (Tensor): Coordinates of the output points, with shape
(M1 + M2 ..., 3).
new_xyz_batch_cnt (Tensor): Number of stacked output points in
each sample, e.g. (M1, M2, ...).
features (Tensor, optional): Features of each point with shape
(N1 + N2 ..., C). C is the number of feature channels. Default: None.
Returns:
Return new point coordinates and features:
- new_xyz (Tensor): Target point coordinates with shape
(M1 + M2 ..., 3).
- new_features (Tensor): Target point features with shape
(M1 + M2 ..., sum_k(mlps[k][-1])).
"""
new_features_list = []
for k in range(len(self.groupers)):
grouped_features, ball_idxs = self.groupers[k](
xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt,
features) # (M1 + M2, Cin, nsample)
grouped_features = grouped_features.permute(1, 0,
2).unsqueeze(dim=0)
new_features = self.mlps[k](grouped_features)
# (M1 + M2 ..., Cout, nsample)
if self.pool_mod == 'max':
new_features = new_features.max(-1).values
elif self.pool_mod == 'avg':
new_features = new_features.mean(-1)
else:
raise NotImplementedError
new_features = new_features.squeeze(dim=0).permute(1, 0)
new_features_list.append(new_features)
new_features = torch.cat(new_features_list, dim=1)
return new_xyz, new_features
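A hedged sketch of the stacked-batch calling convention used by StackedSAModuleMSG: point clouds of different sizes are concatenated along the first dimension and per-sample counts are passed separately. The import path is assumed from this commit's layout, and a CUDA device is required because ball_query and grouping_operation are CUDA ops.
# Sketch only: shapes are illustrative; requires a CUDA build of mmcv.
import torch
from mmdet3d.models.layers.pointnet_modules import StackedSAModuleMSG  # assumed path

if torch.cuda.is_available():
    sa = StackedSAModuleMSG(
        in_channels=4,
        radius=(0.4, 0.8),
        sample_nums=(16, 16),
        mlp_channels=((16, 16), (16, 16)),
        use_xyz=True).cuda()
    # two samples stacked along dim 0: N1 = 100 and N2 = 80 points
    xyz = torch.rand(180, 3).cuda()
    xyz_batch_cnt = torch.tensor([100, 80], dtype=torch.int32).cuda()
    features = torch.rand(180, 4).cuda()
    # 32 key points per sample, also stacked
    new_xyz = torch.rand(64, 3).cuda()
    new_xyz_batch_cnt = torch.tensor([32, 32], dtype=torch.int32).cuda()
    out_xyz, out_feats = sa(xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features)
    # out_feats: (64, 32) -> two scales of 16 channels each, concatenated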
@@ -2,7 +2,9 @@
from .pillar_scatter import PointPillarsScatter
from .sparse_encoder import SparseEncoder, SparseEncoderSASSD
from .sparse_unet import SparseUNet
from .voxel_set_abstraction import VoxelSetAbstraction
__all__ = [
'PointPillarsScatter', 'SparseEncoder', 'SparseEncoderSASSD', 'SparseUNet',
'VoxelSetAbstraction'
]
@@ -41,6 +41,8 @@ class SparseEncoder(nn.Module):
Defaults to ((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)).
block_type (str, optional): Type of the block to use.
Defaults to 'conv_module'.
return_middle_feats (bool): Whether to output middle features.
Defaults to False.
"""
def __init__(self,
@@ -54,7 +56,8 @@ class SparseEncoder(nn.Module):
64)),
encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1)),
block_type='conv_module',
return_middle_feats=False):
super().__init__()
assert block_type in ['conv_module', 'basicblock']
self.sparse_shape = sparse_shape
@@ -66,6 +69,7 @@ class SparseEncoder(nn.Module):
self.encoder_paddings = encoder_paddings
self.stage_num = len(self.encoder_channels)
self.fp16_enabled = False
self.return_middle_feats = return_middle_feats
# Spconv init all weight on its own
assert isinstance(order, tuple) and len(order) == 3
@@ -117,7 +121,14 @@ class SparseEncoder(nn.Module):
batch_size (int): Batch size.
Returns:
torch.Tensor | tuple[torch.Tensor, list]: Return spatial features
that include:
- spatial_features (torch.Tensor): Spatial features output by the
last encoder layer.
- encode_features (List[SparseConvTensor], optional): Middle layer
output features, returned only when self.return_middle_feats
is True.
"""
coors = coors.int()
input_sp_tensor = SparseConvTensor(voxel_features, coors,
@@ -137,7 +148,10 @@ class SparseEncoder(nn.Module):
N, C, D, H, W = spatial_features.shape
spatial_features = spatial_features.view(N, C * D, H, W)
if self.return_middle_feats:
return spatial_features, encode_features
else:
return spatial_features
def make_encoder_layers(self,
make_block,
...
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import mmengine
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops.furthest_point_sample import furthest_point_sample
from mmengine.model import BaseModule
from mmdet3d.registry import MODELS
from mmdet3d.utils import InstanceList
def bilinear_interpolate_torch(inputs, x, y):
"""Bilinear interpolate for inputs."""
x0 = torch.floor(x).long()
x1 = x0 + 1
y0 = torch.floor(y).long()
y1 = y0 + 1
x0 = torch.clamp(x0, 0, inputs.shape[1] - 1)
x1 = torch.clamp(x1, 0, inputs.shape[1] - 1)
y0 = torch.clamp(y0, 0, inputs.shape[0] - 1)
y1 = torch.clamp(y1, 0, inputs.shape[0] - 1)
Ia = inputs[y0, x0]
Ib = inputs[y1, x0]
Ic = inputs[y0, x1]
Id = inputs[y1, x1]
wa = (x1.type_as(x) - x) * (y1.type_as(y) - y)
wb = (x1.type_as(x) - x) * (y - y0.type_as(y))
wc = (x - x0.type_as(x)) * (y1.type_as(y) - y)
wd = (x - x0.type_as(x)) * (y - y0.type_as(y))
ans = torch.t((torch.t(Ia) * wa)) + torch.t(torch.t(Ib) * wb) + torch.t(
torch.t(Ic) * wc) + torch.t(torch.t(Id) * wd)
return ans
@MODELS.register_module()
class VoxelSetAbstraction(BaseModule):
"""Voxel set abstraction module for PVRCNN and PVRCNN++.
Args:
num_keypoints (int): The number of key points sampled from the
raw point cloud.
fused_out_channel (int): Number of output channels of the fused
key point features. Defaults to 128.
voxel_size (list[float]): Size of voxels. Defaults to
[0.05, 0.05, 0.1].
point_cloud_range (list[float]): Point cloud range. Defaults to
[0, -40, -3, 70.4, 40, 1].
voxel_sa_cfgs_list (List[dict or ConfigDict], optional): List of SA
module cfgs used to gather key point features from multi-scale
voxel features. Defaults to None.
rawpoints_sa_cfgs (dict or ConfigDict, optional): SA module cfg
used to gather key point features from the raw points.
Defaults to None.
bev_feat_channel (int): Number of BEV feature channels.
Defaults to 256.
bev_scale_factor (int): BEV feature scale factor. Defaults to 8.
voxel_center_as_source (bool): Whether to use voxel centers as
key points. Defaults to False.
norm_cfg (dict[str]): Config of the normalization layer. Defaults
to dict(type='BN2d', eps=1e-5, momentum=0.1).
bias (bool | str, optional): If specified as `auto`, it will be
decided by `norm_cfg`. `bias` will be set as True if
`norm_cfg` is None, otherwise False. Default: 'auto'.
"""
def __init__(self,
num_keypoints: int,
fused_out_channel: int = 128,
voxel_size: list = [0.05, 0.05, 0.1],
point_cloud_range: list = [0, -40, -3, 70.4, 40, 1],
voxel_sa_cfgs_list: Optional[list] = None,
rawpoints_sa_cfgs: Optional[dict] = None,
bev_feat_channel: int = 256,
bev_scale_factor: int = 8,
voxel_center_as_source: bool = False,
norm_cfg: dict = dict(type='BN2d', eps=1e-5, momentum=0.1),
bias: str = 'auto') -> None:
super().__init__()
self.num_keypoints = num_keypoints
self.fused_out_channel = fused_out_channel
self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range
self.voxel_center_as_source = voxel_center_as_source
gathered_channel = 0
if rawpoints_sa_cfgs is not None:
self.rawpoints_sa_layer = MODELS.build(rawpoints_sa_cfgs)
gathered_channel += sum(
[x[-1] for x in rawpoints_sa_cfgs.mlp_channels])
else:
self.rawpoints_sa_layer = None
if voxel_sa_cfgs_list is not None:
self.voxel_sa_configs_list = voxel_sa_cfgs_list
self.voxel_sa_layers = nn.ModuleList()
for voxel_sa_config in voxel_sa_cfgs_list:
cur_layer = MODELS.build(voxel_sa_config)
self.voxel_sa_layers.append(cur_layer)
gathered_channel += sum(
[x[-1] for x in voxel_sa_config.mlp_channels])
else:
self.voxel_sa_layers = None
if bev_feat_channel is not None and bev_scale_factor is not None:
self.bev_cfg = mmengine.Config(
dict(
bev_feat_channels=bev_feat_channel,
bev_scale_factor=bev_scale_factor))
gathered_channel += bev_feat_channel
else:
self.bev_cfg = None
self.point_feature_fusion_layer = nn.Sequential(
ConvModule(
gathered_channel,
fused_out_channel,
kernel_size=(1, 1),
stride=(1, 1),
conv_cfg=dict(type='Conv2d'),
norm_cfg=norm_cfg,
bias=bias))
def interpolate_from_bev_features(self, keypoints: torch.Tensor,
bev_features: torch.Tensor,
batch_size: int,
bev_scale_factor: int) -> torch.Tensor:
"""Gather key points features from bev feature map by interpolate.
Args:
keypoints (torch.Tensor): Sampled key points with shape
(B, num_keypoints, 3).
bev_features (torch.Tensor): BEV feature map from the first
stage with shape (B, C, H, W).
batch_size (int): Input batch size.
bev_scale_factor (int): BEV feature map scale factor.
Returns:
torch.Tensor: Key point features gathered from the BEV feature
map, with shape (B, num_keypoints, C).
"""
x_idxs = (keypoints[..., 0] -
self.point_cloud_range[0]) / self.voxel_size[0]
y_idxs = (keypoints[..., 1] -
self.point_cloud_range[1]) / self.voxel_size[1]
x_idxs = x_idxs / bev_scale_factor
y_idxs = y_idxs / bev_scale_factor
point_bev_features_list = []
for k in range(batch_size):
cur_x_idxs = x_idxs[k, ...]
cur_y_idxs = y_idxs[k, ...]
cur_bev_features = bev_features[k].permute(1, 2, 0) # (H, W, C)
point_bev_features = bilinear_interpolate_torch(
cur_bev_features, cur_x_idxs, cur_y_idxs)
point_bev_features_list.append(point_bev_features)
point_bev_features = torch.cat(
point_bev_features_list, dim=0) # (N1 + N2 + ..., C)
return point_bev_features.view(batch_size, keypoints.shape[1], -1)
def get_voxel_centers(self, coors: torch.Tensor,
scale_factor: float) -> torch.Tensor:
"""Get voxel centers coordinate.
Args:
coors (torch.Tensor): Coordinates of voxels shape is Nx(1+NDim),
where 1 represents the batch index.
scale_factor (float): Scale factor.
Returns:
torch.Tensor: Voxel centers coordinate with shape (N, 3).
"""
assert coors.shape[1] == 4
voxel_centers = coors[:, [3, 2, 1]].float() # (xyz)
voxel_size = torch.tensor(
self.voxel_size,
device=voxel_centers.device).float() * scale_factor
pc_range = torch.tensor(
self.point_cloud_range[0:3], device=voxel_centers.device).float()
voxel_centers = (voxel_centers + 0.5) * voxel_size + pc_range
return voxel_centers
def sample_key_points(self, points: List[torch.Tensor],
coors: torch.Tensor) -> torch.Tensor:
"""Sample key points from raw points cloud.
Args:
points (List[torch.Tensor]): Point cloud of each sample.
coors (torch.Tensor): Coordinates of voxels shape is Nx(1+NDim),
where 1 represents the batch index.
Returns:
torch.Tensor: (B, M, 3) Key points of each sample.
M is num_keypoints.
"""
assert points is not None or coors is not None
if self.voxel_center_as_source:
_src_points = self.get_voxel_centers(coors=coors, scale_factor=1)
batch_size = coors[-1, 0].item() + 1
src_points = [
_src_points[coors[:, 0] == b] for b in range(batch_size)
]
else:
src_points = [p[..., :3] for p in points]
keypoints_list = []
for points_to_sample in src_points:
num_points = points_to_sample.shape[0]
cur_pt_idxs = furthest_point_sample(
points_to_sample.unsqueeze(dim=0).contiguous(),
self.num_keypoints).long()[0]
if num_points < self.num_keypoints:
times = int(self.num_keypoints / num_points) + 1
non_empty = cur_pt_idxs[:num_points]
cur_pt_idxs = non_empty.repeat(times)[:self.num_keypoints]
keypoints = points_to_sample[cur_pt_idxs]
keypoints_list.append(keypoints)
keypoints = torch.stack(keypoints_list, dim=0) # (B, M, 3)
return keypoints
def forward(self, batch_inputs_dict: dict, feats_dict: dict,
rpn_results_list: InstanceList) -> dict:
"""Extract point-wise features from multi-input.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
feats_dict (dict): Contains features from the first
stage.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
Returns:
dict: Contains the point-wise features, including:
- keypoints (torch.Tensor): Sampled key points.
- keypoint_features (torch.Tensor): Key point features gathered
from the multiple inputs.
- fusion_keypoint_features (torch.Tensor): Key point features
fused by point_feature_fusion_layer.
"""
points = batch_inputs_dict['points']
voxel_encode_features = feats_dict['multi_scale_3d_feats']
bev_encode_features = feats_dict['spatial_feats']
if self.voxel_center_as_source:
voxels_coors = batch_inputs_dict['voxels']['coors']
else:
voxels_coors = None
keypoints = self.sample_key_points(points, voxels_coors)
point_features_list = []
batch_size = len(points)
if self.bev_cfg is not None:
point_bev_features = self.interpolate_from_bev_features(
keypoints, bev_encode_features, batch_size,
self.bev_cfg.bev_scale_factor)
point_features_list.append(point_bev_features.contiguous())
batch_size, num_keypoints, _ = keypoints.shape
key_xyz = keypoints.view(-1, 3)
key_xyz_batch_cnt = key_xyz.new_zeros(batch_size).int().fill_(
num_keypoints)
if self.rawpoints_sa_layer is not None:
batch_points = torch.cat(points, dim=0)
batch_cnt = [len(p) for p in points]
xyz = batch_points[:, :3].contiguous()
features = None
if batch_points.size(1) > 0:
features = batch_points[:, 3:].contiguous()
xyz_batch_cnt = xyz.new_tensor(batch_cnt, dtype=torch.int32)
pooled_points, pooled_features = self.rawpoints_sa_layer(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=key_xyz.contiguous(),
new_xyz_batch_cnt=key_xyz_batch_cnt,
features=features.contiguous(),
)
point_features_list.append(pooled_features.contiguous().view(
batch_size, num_keypoints, -1))
if self.voxel_sa_layers is not None:
for k, voxel_sa_layer in enumerate(self.voxel_sa_layers):
cur_coords = voxel_encode_features[k].indices
xyz = self.get_voxel_centers(
coors=cur_coords,
scale_factor=self.voxel_sa_configs_list[k].scale_factor
).contiguous()
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (cur_coords[:, 0] == bs_idx).sum()
pooled_points, pooled_features = voxel_sa_layer(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=key_xyz.contiguous(),
new_xyz_batch_cnt=key_xyz_batch_cnt,
features=voxel_encode_features[k].features.contiguous(),
)
point_features_list.append(pooled_features.contiguous().view(
batch_size, num_keypoints, -1))
point_features = torch.cat(
point_features_list, dim=-1).view(batch_size * num_keypoints, -1,
1)
fusion_point_features = self.point_feature_fusion_layer(
point_features.unsqueeze(dim=-1)).squeeze(dim=-1)
batch_idxs = torch.arange(
batch_size * num_keypoints, device=keypoints.device
) // num_keypoints # batch indexes of each key points
batch_keypoints_xyz = torch.cat(
(batch_idxs.to(key_xyz.dtype).unsqueeze(dim=-1), key_xyz), dim=-1)
return dict(
keypoint_features=point_features.squeeze(dim=-1),
fusion_keypoint_features=fusion_point_features.squeeze(dim=-1),
keypoints=batch_keypoints_xyz)
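The bilinear_interpolate_torch helper at the top of this file can be checked on CPU in isolation; the feature map and query locations below are illustrative only, and the import path follows the middle_encoders package this commit registers the module in.
# Tiny CPU check of the bilinear interpolation helper defined above.
import torch
from mmdet3d.models.middle_encoders.voxel_set_abstraction import \
    bilinear_interpolate_torch  # assumed import path

feat_map = torch.arange(12, dtype=torch.float32).view(3, 4, 1)  # (H=3, W=4, C=1)
x = torch.tensor([1.5])  # fractional column index
y = torch.tensor([0.5])  # fractional row index
val = bilinear_interpolate_torch(feat_map, x, y)
# average of feat_map[0, 1], feat_map[0, 2], feat_map[1, 1], feat_map[1, 2]
# = (1 + 2 + 5 + 6) / 4 = 3.5
print(val)  # tensor([[3.5000]])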
@@ -5,10 +5,11 @@ from .h3d_roi_head import H3DRoIHead
from .mask_heads import PointwiseSemanticHead, PrimitiveHead
from .part_aggregation_roi_head import PartAggregationROIHead
from .point_rcnn_roi_head import PointRCNNRoIHead
from .pv_rcnn_roi_head import PVRCNNRoiHead
from .roi_extractors import Single3DRoIAwareExtractor, SingleRoIExtractor
__all__ = [
'Base3DRoIHead', 'PartAggregationROIHead', 'PointwiseSemanticHead',
'Single3DRoIAwareExtractor', 'PartA2BboxHead', 'SingleRoIExtractor',
'H3DRoIHead', 'PrimitiveHead', 'PointRCNNRoIHead', 'PVRCNNRoiHead'
]
@@ -7,9 +7,10 @@ from mmdet.models.roi_heads.bbox_heads import (BBoxHead, ConvFCBBoxHead,
from .h3d_bbox_head import H3DBboxHead
from .parta2_bbox_head import PartA2BboxHead
from .point_rcnn_bbox_head import PointRCNNBboxHead
from .pv_rcnn_bbox_head import PVRCNNBBoxHead
__all__ = [
'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead',
'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'PartA2BboxHead',
'H3DBboxHead', 'PointRCNNBboxHead', 'PVRCNNBBoxHead'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import nn as nn
from mmdet3d.models.builder import build_loss
from mmdet3d.models.layers import nms_bev, nms_normal_bev
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.utils import InstanceList
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import multi_apply
@MODELS.register_module()
class PVRCNNBBoxHead(BaseModule):
"""PVRCNN BBox head.
Args:
in_channels (int): The number of input channel.
grid_size (int): The number of grid points in roi bbox.
num_classes (int): The number of classes.
class_agnostic (bool): Whether generate class agnostic prediction.
Defaults to True.
shared_fc_channels (tuple(int)): Out channels of each shared fc layer.
Defaults to (256, 256).
cls_channels (tuple(int)): Out channels of each classification layer.
Defaults to (256, 256).
reg_channels (tuple(int)): Out channels of each regression layer.
Defaults to (256, 256).
dropout_ratio (float): Ratio of dropout layer. Defaults to 0.3.
with_corner_loss (bool): Whether to use corner loss or not.
Defaults to True.
bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for box head.
Defaults to dict(type='DeltaXYZWLHRBBoxCoder').
norm_cfg (dict): Type of normalization method.
Defaults to dict(type='BN2d', eps=1e-5, momentum=0.1).
loss_bbox (dict): Config dict of box regression loss.
loss_cls (dict): Config dict of classification loss.
init_cfg (dict, optional): Initialize config of
model.
"""
def __init__(
self,
in_channels: int,
grid_size: int,
num_classes: int,
class_agnostic: bool = True,
shared_fc_channels: Tuple[int] = (256, 256),
cls_channels: Tuple[int] = (256, 256),
reg_channels: Tuple[int] = (256, 256),
dropout_ratio: float = 0.3,
with_corner_loss: bool = True,
bbox_coder: dict = dict(type='DeltaXYZWLHRBBoxCoder'),
norm_cfg: dict = dict(type='BN2d', eps=1e-5, momentum=0.1),
loss_bbox: dict = dict(
type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_cls: dict = dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='none',
loss_weight=1.0),
init_cfg: Optional[dict] = dict(
type='Xavier', layer=['Conv2d', 'Conv1d'], distribution='uniform')
) -> None:
super(PVRCNNBBoxHead, self).__init__(init_cfg=init_cfg)
self.init_cfg = init_cfg
self.num_classes = num_classes
self.with_corner_loss = with_corner_loss
self.class_agnostic = class_agnostic
self.bbox_coder = TASK_UTILS.build(bbox_coder)
self.loss_bbox = build_loss(loss_bbox)
self.loss_cls = build_loss(loss_cls)
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
cls_out_channels = 1 if class_agnostic else num_classes
self.reg_out_channels = self.bbox_coder.code_size * cls_out_channels
if self.use_sigmoid_cls:
self.cls_out_channels = cls_out_channels
else:
self.cls_out_channels = cls_out_channels + 1
self.dropout_ratio = dropout_ratio
self.grid_size = grid_size
# The head consumes flattened RoI grid features, so in_channels is
# multiplied by the number of grid points per RoI box.
in_channels *= (self.grid_size**3)
self.in_channels = in_channels
self.shared_fc_layer = self._make_fc_layers(
in_channels, shared_fc_channels,
range(len(shared_fc_channels) - 1), norm_cfg)
self.cls_layer = self._make_fc_layers(
shared_fc_channels[-1],
cls_channels,
range(1),
norm_cfg,
out_channels=self.cls_out_channels)
self.reg_layer = self._make_fc_layers(
shared_fc_channels[-1],
reg_channels,
range(1),
norm_cfg,
out_channels=self.reg_out_channels)
def _make_fc_layers(self,
in_channels: int,
fc_channels: list,
dropout_indices: list,
norm_cfg: dict,
out_channels: Optional[int] = None) -> torch.nn.Module:
"""Initial a full connection layer.
Args:
in_channels (int): Module in channels.
fc_channels (list): Fully connected layer channels.
dropout_indices (list): Dropout indices.
norm_cfg (dict): Type of normalization method.
out_channels (int, optional): Module out channels.
"""
fc_layers = []
pre_channel = in_channels
for k in range(len(fc_channels)):
fc_layers.append(
ConvModule(
pre_channel,
fc_channels[k],
kernel_size=(1, 1),
stride=(1, 1),
norm_cfg=norm_cfg,
conv_cfg=dict(type='Conv2d'),
bias=False,
inplace=True))
pre_channel = fc_channels[k]
if self.dropout_ratio >= 0 and k in dropout_indices:
fc_layers.append(nn.Dropout(self.dropout_ratio))
if out_channels is not None:
fc_layers.append(
nn.Conv2d(fc_channels[-1], out_channels, 1, bias=True))
fc_layers = nn.Sequential(*fc_layers)
return fc_layers
def forward(self, feats: torch.Tensor) -> Tuple[torch.Tensor]:
"""Forward pvrcnn bbox head.
Args:
feats (torch.Tensor): Batch point-wise features.
Returns:
tuple[torch.Tensor]: Score of class and bbox predictions.
"""
# (B * N, 6, 6, 6, C)
rcnn_batch_size = feats.shape[0]
feats = feats.permute(0, 4, 1, 2,
3).contiguous().view(rcnn_batch_size, -1, 1, 1)
# (BxN, C*6*6*6)
shared_feats = self.shared_fc_layer(feats)
cls_score = self.cls_layer(shared_feats).transpose(
1, 2).contiguous().view(-1, self.cls_out_channels) # (B, 1)
bbox_pred = self.reg_layer(shared_feats).transpose(
1, 2).contiguous().view(-1, self.reg_out_channels) # (B, C)
return cls_score, bbox_pred
def loss(self, cls_score: torch.Tensor, bbox_pred: torch.Tensor,
rois: torch.Tensor, labels: torch.Tensor,
bbox_targets: torch.Tensor, pos_gt_bboxes: torch.Tensor,
reg_mask: torch.Tensor, label_weights: torch.Tensor,
bbox_weights: torch.Tensor) -> Dict:
"""Coumputing losses.
Args:
cls_score (torch.Tensor): Scores of each roi.
bbox_pred (torch.Tensor): Predictions of bboxes.
rois (torch.Tensor): Roi bboxes.
labels (torch.Tensor): Labels of class.
bbox_targets (torch.Tensor): Target of positive bboxes.
pos_gt_bboxes (torch.Tensor): Ground truths of positive bboxes.
reg_mask (torch.Tensor): Mask for positive bboxes.
label_weights (torch.Tensor): Weights of class loss.
bbox_weights (torch.Tensor): Weights of bbox loss.
Returns:
dict: Computed losses.
- loss_cls (torch.Tensor): Loss of classes.
- loss_bbox (torch.Tensor): Loss of bboxes.
- loss_corner (torch.Tensor): Loss of corners.
"""
losses = dict()
rcnn_batch_size = cls_score.shape[0]
# calculate class loss
cls_flat = cls_score.view(-1)
loss_cls = self.loss_cls(cls_flat, labels, label_weights)
losses['loss_cls'] = loss_cls
# calculate regression loss
code_size = self.bbox_coder.code_size
pos_inds = (reg_mask > 0)
if not pos_inds.any():
# fake a part loss
losses['loss_bbox'] = 0 * bbox_pred.sum()
if self.with_corner_loss:
losses['loss_corner'] = 0 * bbox_pred.sum()
else:
pos_bbox_pred = bbox_pred.view(rcnn_batch_size, -1)[pos_inds]
bbox_weights_flat = bbox_weights[pos_inds].view(-1, 1).repeat(
1, pos_bbox_pred.shape[-1])
loss_bbox = self.loss_bbox(
pos_bbox_pred.unsqueeze(dim=0), bbox_targets.unsqueeze(dim=0),
bbox_weights_flat.unsqueeze(dim=0))
losses['loss_bbox'] = loss_bbox
if self.with_corner_loss:
pos_roi_boxes3d = rois[..., 1:].view(-1, code_size)[pos_inds]
pos_roi_boxes3d = pos_roi_boxes3d.view(-1, code_size)
batch_anchors = pos_roi_boxes3d.clone().detach()
pos_rois_rotation = pos_roi_boxes3d[..., 6].view(-1)
roi_xyz = pos_roi_boxes3d[..., 0:3].view(-1, 3)
batch_anchors[..., 0:3] = 0
# decode boxes
pred_boxes3d = self.bbox_coder.decode(
batch_anchors,
pos_bbox_pred.view(-1, code_size)).view(-1, code_size)
pred_boxes3d[..., 0:3] = rotation_3d_in_axis(
pred_boxes3d[..., 0:3].unsqueeze(1),
pos_rois_rotation,
axis=2).squeeze(1)
pred_boxes3d[:, 0:3] += roi_xyz
# calculate corner loss
loss_corner = self.get_corner_loss_lidar(
pred_boxes3d, pos_gt_bboxes)
losses['loss_corner'] = loss_corner.mean()
return losses
def get_targets(self,
sampling_results: SamplingResult,
rcnn_train_cfg: dict,
concat: bool = True) -> Tuple[torch.Tensor]:
"""Generate targets.
Args:
sampling_results (list[:obj:`SamplingResult`]):
Sampled results from rois.
rcnn_train_cfg (:obj:`ConfigDict`): Training config of rcnn.
concat (bool): Whether to concatenate targets between batches.
Returns:
tuple[torch.Tensor]: Targets of boxes and class prediction.
"""
pos_bboxes_list = [res.pos_bboxes for res in sampling_results]
pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
iou_list = [res.iou for res in sampling_results]
targets = multi_apply(
self._get_target_single,
pos_bboxes_list,
pos_gt_bboxes_list,
iou_list,
cfg=rcnn_train_cfg)
(label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights,
bbox_weights) = targets
if concat:
label = torch.cat(label, 0)
bbox_targets = torch.cat(bbox_targets, 0)
pos_gt_bboxes = torch.cat(pos_gt_bboxes, 0)
reg_mask = torch.cat(reg_mask, 0)
label_weights = torch.cat(label_weights, 0)
label_weights /= torch.clamp(label_weights.sum(), min=1.0)
bbox_weights = torch.cat(bbox_weights, 0)
bbox_weights /= torch.clamp(bbox_weights.sum(), min=1.0)
return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights,
bbox_weights)
def _get_target_single(self, pos_bboxes: torch.Tensor,
pos_gt_bboxes: torch.Tensor, ious: torch.Tensor,
cfg: dict) -> Tuple[torch.Tensor]:
"""Generate training targets for a single sample.
Args:
pos_bboxes (torch.Tensor): Positive boxes with shape
(N, 7).
pos_gt_bboxes (torch.Tensor): Ground truth boxes with shape
(M, 7).
ious (torch.Tensor): IoU between `pos_bboxes` and `pos_gt_bboxes`
in shape (N, M).
cfg (dict): Training configs.
Returns:
tuple[torch.Tensor]: Target for positive boxes.
(label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights,
bbox_weights)
"""
cls_pos_mask = ious > cfg.cls_pos_thr
cls_neg_mask = ious < cfg.cls_neg_thr
interval_mask = (cls_pos_mask == 0) & (cls_neg_mask == 0)
# iou regression target
label = (cls_pos_mask > 0).float()
label[interval_mask] = ious[interval_mask] * 2 - 0.5
# label weights
label_weights = (label >= 0).float()
# box regression target
reg_mask = pos_bboxes.new_zeros(ious.size(0)).long()
reg_mask[0:pos_gt_bboxes.size(0)] = 1
bbox_weights = (reg_mask > 0).float()
if reg_mask.bool().any():
pos_gt_bboxes_ct = pos_gt_bboxes.clone().detach()
roi_center = pos_bboxes[..., 0:3]
roi_ry = pos_bboxes[..., 6] % (2 * np.pi)
# canonical transformation
pos_gt_bboxes_ct[..., 0:3] -= roi_center
pos_gt_bboxes_ct[..., 6] -= roi_ry
pos_gt_bboxes_ct[..., 0:3] = rotation_3d_in_axis(
pos_gt_bboxes_ct[..., 0:3].unsqueeze(1), -roi_ry,
axis=2).squeeze(1)
# flip orientation if rois have opposite orientation
ry_label = pos_gt_bboxes_ct[..., 6] % (2 * np.pi) # 0 ~ 2pi
opposite_flag = (ry_label > np.pi * 0.5) & (ry_label < np.pi * 1.5)
ry_label[opposite_flag] = (ry_label[opposite_flag] + np.pi) % (
2 * np.pi) # (0 ~ pi/2, 3pi/2 ~ 2pi)
flag = ry_label > np.pi
ry_label[flag] = ry_label[flag] - np.pi * 2 # (-pi/2, pi/2)
ry_label = torch.clamp(ry_label, min=-np.pi / 2, max=np.pi / 2)
pos_gt_bboxes_ct[..., 6] = ry_label
rois_anchor = pos_bboxes.clone().detach()
rois_anchor[:, 0:3] = 0
rois_anchor[:, 6] = 0
bbox_targets = self.bbox_coder.encode(rois_anchor,
pos_gt_bboxes_ct)
else:
# no fg bbox
bbox_targets = pos_gt_bboxes.new_empty((0, 7))
return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights,
bbox_weights)
def get_corner_loss_lidar(self,
pred_bbox3d: torch.Tensor,
gt_bbox3d: torch.Tensor,
delta: float = 1.0) -> torch.Tensor:
"""Calculate corner loss of given boxes.
Args:
pred_bbox3d (torch.FloatTensor): Predicted boxes in shape (N, 7).
gt_bbox3d (torch.FloatTensor): Ground truth boxes in shape (N, 7).
delta (float, optional): Huber loss threshold. Defaults to 1.0.
Returns:
torch.FloatTensor: Calculated corner loss in shape (N).
"""
assert pred_bbox3d.shape[0] == gt_bbox3d.shape[0]
# This is a little bit hacky here because we assume the boxes
# are in LiDAR coordinates
gt_boxes_structure = LiDARInstance3DBoxes(gt_bbox3d)
pred_box_corners = LiDARInstance3DBoxes(pred_bbox3d).corners
gt_box_corners = gt_boxes_structure.corners
# This flip only changes the heading direction of GT boxes
gt_bbox3d_flip = gt_boxes_structure.clone()
gt_bbox3d_flip.tensor[:, 6] += np.pi
gt_box_corners_flip = gt_bbox3d_flip.corners
corner_dist = torch.min(
torch.norm(pred_box_corners - gt_box_corners, dim=2),
torch.norm(pred_box_corners - gt_box_corners_flip,
dim=2)) # (N, 8)
# huber loss
abs_error = torch.abs(corner_dist)
corner_loss = torch.where(abs_error < delta,
0.5 * abs_error**2 / delta,
abs_error - 0.5 * delta)
return corner_loss.mean(dim=1)
def get_results(self,
rois: torch.Tensor,
cls_preds: torch.Tensor,
bbox_reg: torch.Tensor,
class_labels: torch.Tensor,
input_metas: List[dict],
test_cfg: dict = None) -> InstanceList:
"""Generate bboxes from bbox head predictions.
Args:
rois (torch.Tensor): Roi bounding boxes.
cls_preds (torch.Tensor): Scores of bounding boxes.
bbox_reg (torch.Tensor): Bounding box predictions.
class_labels (torch.Tensor): Labels of classes.
input_metas (list[dict]): Point cloud meta info.
test_cfg (:obj:`ConfigDict`): Testing config.
Returns:
list[:obj:`InstanceData`]: Detection results of each sample
after the post process.
Each item usually contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
contains a tensor with shape (num_instances, C), where
C >= 7.
"""
roi_batch_id = rois[..., 0]
roi_boxes = rois[..., 1:] # boxes without batch id
batch_size = int(roi_batch_id.max().item() + 1)
# decode boxes
roi_ry = roi_boxes[..., 6].view(-1)
roi_xyz = roi_boxes[..., 0:3].view(-1, 3)
local_roi_boxes = roi_boxes.clone().detach()
local_roi_boxes[..., 0:3] = 0
batch_box_preds = self.bbox_coder.decode(local_roi_boxes, bbox_reg)
batch_box_preds[..., 0:3] = rotation_3d_in_axis(
batch_box_preds[..., 0:3].unsqueeze(1), roi_ry, axis=2).squeeze(1)
batch_box_preds[:, 0:3] += roi_xyz
# post processing
result_list = []
for batch_id in range(batch_size):
cls_preds = cls_preds[roi_batch_id == batch_id]
box_preds = batch_box_preds[roi_batch_id == batch_id]
label_preds = class_labels[batch_id]
cls_preds = cls_preds.sigmoid()
cls_preds, _ = torch.max(cls_preds, dim=-1)
selected = self.class_agnostic_nms(
scores=cls_preds,
bbox_preds=box_preds,
input_meta=input_metas[batch_id],
nms_cfg=test_cfg)
selected_bboxes = box_preds[selected]
selected_label_preds = label_preds[selected]
selected_scores = cls_preds[selected]
results = InstanceData()
results.bboxes_3d = input_metas[batch_id]['box_type_3d'](
selected_bboxes, self.bbox_coder.code_size)
results.scores_3d = selected_scores
results.labels_3d = selected_label_preds
result_list.append(results)
return result_list
def class_agnostic_nms(self, scores: torch.Tensor,
bbox_preds: torch.Tensor, nms_cfg: dict,
input_meta: dict) -> Tuple[torch.Tensor]:
"""Class agnostic NMS for box head.
Args:
scores (torch.Tensor): Object score of bounding boxes.
bbox_preds (torch.Tensor): Predicted bounding boxes.
nms_cfg (dict): NMS config dict.
input_meta (dict): Contains the point cloud and image meta info.
Returns:
tuple[torch.Tensor]: Bounding boxes, scores and labels.
"""
obj_scores = scores.clone()
if nms_cfg.use_rotate_nms:
nms_func = nms_bev
else:
nms_func = nms_normal_bev
bbox = input_meta['box_type_3d'](
bbox_preds.clone(),
box_dim=bbox_preds.shape[-1],
with_yaw=True,
origin=(0.5, 0.5, 0.5))
if nms_cfg.score_thr is not None:
scores_mask = (obj_scores >= nms_cfg.score_thr)
obj_scores = obj_scores[scores_mask]
bbox = bbox[scores_mask]
selected = []
if obj_scores.shape[0] > 0:
box_scores_nms, indices = torch.topk(
obj_scores, k=min(4096, obj_scores.shape[0]))
bbox_bev = bbox.bev[indices]
bbox_for_nms = xywhr2xyxyr(bbox_bev)
keep = nms_func(bbox_for_nms, box_scores_nms, nms_cfg.nms_thr)
selected = indices[keep]
if nms_cfg.score_thr is not None:
original_idxs = scores_mask.nonzero().view(-1)
selected = original_idxs[selected]
return selected
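The Huber smoothing used in get_corner_loss_lidar can be checked in isolation with plain torch; the corner distances and delta below are illustrative only.
# Stand-alone CPU check of the Huber smoothing from get_corner_loss_lidar.
import torch

delta = 1.0
corner_dist = torch.tensor([[0.2, 2.0]])  # (N=1, 2 corners for brevity)
abs_error = corner_dist.abs()
corner_loss = torch.where(abs_error < delta,
                          0.5 * abs_error ** 2 / delta,
                          abs_error - 0.5 * delta)
# per-corner losses: 0.5 * 0.2**2 = 0.02 and 2.0 - 0.5 = 1.5; mean = 0.76
print(corner_loss.mean(dim=1))  # tensor([0.7600])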
# Copyright (c) OpenMMLab. All rights reserved.
from .foreground_segmentation_head import ForegroundSegmentationHead
from .pointwise_semantic_head import PointwiseSemanticHead
from .primitive_head import PrimitiveHead
__all__ = [
'PointwiseSemanticHead', 'PrimitiveHead', 'ForegroundSegmentationHead'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple
import torch
from mmcv.cnn.bricks import build_norm_layer
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import nn as nn
from mmdet3d.models.builder import build_loss
from mmdet3d.registry import MODELS
from mmdet3d.utils import InstanceList
from mmdet.models.utils import multi_apply
@MODELS.register_module()
class ForegroundSegmentationHead(BaseModule):
"""Foreground segmentation head.
Args:
in_channels (int): The number of input channel.
mlp_channels (tuple[int]): Channels of the MLP layers. Defaults
to (256, 256).
extra_width (float): Extra width by which ground truth boxes are
enlarged. Defaults to 0.1.
norm_cfg (dict): Type of normalization method. Defaults to
dict(type='BN1d', eps=1e-5, momentum=0.1).
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
loss_seg (dict): Config of segmentation loss. Defaults to
dict(type='mmdet.FocalLoss')
"""
def __init__(
self,
in_channels: int,
mlp_channels: Tuple[int] = (256, 256),
extra_width: float = 0.1,
norm_cfg: dict = dict(type='BN1d', eps=1e-5, momentum=0.1),
init_cfg: Optional[dict] = None,
loss_seg: dict = dict(
type='mmdet.FocalLoss',
use_sigmoid=True,
reduction='sum',
gamma=2.0,
alpha=0.25,
activated=True,
loss_weight=1.0)
) -> None:
super(ForegroundSegmentationHead, self).__init__(init_cfg=init_cfg)
self.extra_width = extra_width
self.num_classes = 1
self.in_channels = in_channels
self.use_sigmoid_cls = loss_seg.get('use_sigmoid', False)
out_channels = 1
if self.use_sigmoid_cls:
self.out_channels = out_channels
else:
self.out_channels = out_channels + 1
mlps_layers = []
cin = in_channels
for mlp in mlp_channels:
mlps_layers.extend([
nn.Linear(cin, mlp, bias=False),
build_norm_layer(norm_cfg, mlp)[1],
nn.ReLU()
])
cin = mlp
mlps_layers.append(nn.Linear(cin, self.out_channels, bias=True))
self.seg_cls_layer = nn.Sequential(*mlps_layers)
self.loss_seg = build_loss(loss_seg)
def forward(self, feats: torch.Tensor) -> dict:
"""Forward head.
Args:
feats (torch.Tensor): Point-wise features.
Returns:
dict: Segment predictions.
"""
seg_preds = self.seg_cls_layer(feats)
return dict(seg_preds=seg_preds)
def _get_targets_single(self, point_xyz: torch.Tensor,
gt_bboxes_3d: InstanceData,
gt_labels_3d: torch.Tensor) -> torch.Tensor:
"""generate segmentation targets for a single sample.
Args:
point_xyz (torch.Tensor): Coordinate of points.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth boxes in
shape (box_num, 7).
gt_labels_3d (torch.Tensor): Class labels of ground truths in
shape (box_num).
Returns:
tuple[torch.Tensor]: Point-wise class labels, wrapped in a tuple for
``multi_apply``.
"""
point_cls_labels_single = point_xyz.new_zeros(
point_xyz.shape[0]).long()
enlarged_gt_boxes = gt_bboxes_3d.enlarged_box(self.extra_width)
box_idxs_of_pts = gt_bboxes_3d.points_in_boxes_part(point_xyz).long()
extend_box_idxs_of_pts = enlarged_gt_boxes.points_in_boxes_part(
point_xyz).long()
box_fg_flag = box_idxs_of_pts >= 0
fg_flag = box_fg_flag.clone()
ignore_flag = fg_flag ^ (extend_box_idxs_of_pts >= 0)
point_cls_labels_single[ignore_flag] = -1
gt_box_of_fg_points = gt_labels_3d[box_idxs_of_pts[fg_flag]]
point_cls_labels_single[fg_flag] = (
1 if self.num_classes == 1 else gt_box_of_fg_points.long())
# return a 1-tuple so this method can be used with `multi_apply`
return (point_cls_labels_single, )
def get_targets(self, points_bxyz: torch.Tensor,
batch_gt_instances_3d: InstanceList) -> dict:
"""Generate segmentation targets.
Args:
points_bxyz (torch.Tensor): Stacked point coordinates in shape
(N, 4), where the first column is the batch index.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
dict: Prediction targets
- seg_targets (torch.Tensor): Segmentation targets.
"""
batch_size = len(batch_gt_instances_3d)
points_xyz_list = []
gt_bboxes_3d = []
gt_labels_3d = []
for idx in range(batch_size):
coords_idx = points_bxyz[:, 0] == idx
points_xyz_list.append(points_bxyz[coords_idx][..., 1:])
gt_bboxes_3d.append(batch_gt_instances_3d[idx].bboxes_3d)
gt_labels_3d.append(batch_gt_instances_3d[idx].labels_3d)
seg_targets, = multi_apply(self._get_targets_single, points_xyz_list,
gt_bboxes_3d, gt_labels_3d)
seg_targets = torch.cat(seg_targets, dim=0)
return dict(seg_targets=seg_targets)
def loss(self, semantic_results: dict,
semantic_targets: dict) -> Dict[str, torch.Tensor]:
"""Calculate point-wise segmentation losses.
Args:
semantic_results (dict): Results from semantic head.
semantic_targets (dict): Targets of semantic results.
Returns:
dict: Loss of segmentation.
- loss_semantic (torch.Tensor): Segmentation prediction loss.
"""
seg_preds = semantic_results['seg_preds']
seg_targets = semantic_targets['seg_targets']
positives = (seg_targets > 0)
negative_cls_weights = (seg_targets == 0).float()
seg_weights = (negative_cls_weights + 1.0 * positives).float()
pos_normalizer = positives.sum(dim=0).float()
seg_weights /= torch.clamp(pos_normalizer, min=1.0)
seg_preds = torch.sigmoid(seg_preds)
loss_seg = self.loss_seg(seg_preds, (~positives).long(), seg_weights)
return dict(loss_semantic=loss_seg)
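# --- Illustrative sketch (not part of the head above) ---
# How the per-point weights in `ForegroundSegmentationHead.loss` behave:
# ignored points (target -1) get weight 0, foreground and background points
# get weight 1, and every weight is divided by the number of positives
# (clamped to at least 1) so the loss scale does not depend on how many
# foreground points a sample happens to contain.
import torch

seg_targets = torch.tensor([1, 0, -1, 1, 0])   # 1 = fg, 0 = bg, -1 = ignore
positives = seg_targets > 0
negative_cls_weights = (seg_targets == 0).float()
seg_weights = (negative_cls_weights + 1.0 * positives).float()
pos_normalizer = positives.sum(dim=0).float()  # 2 positives in this toy case
seg_weights /= torch.clamp(pos_normalizer, min=1.0)
print(seg_weights)                             # tensor([0.5, 0.5, 0.0, 0.5, 0.5])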
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch
from torch.nn import functional as F
from mmdet3d.models.roi_heads.base_3droi_head import Base3DRoIHead
from mmdet3d.registry import MODELS
from mmdet3d.structures import bbox3d2roi
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import InstanceList
from mmdet.models.task_modules import AssignResult
from mmdet.models.task_modules.samplers import SamplingResult
@MODELS.register_module()
class PVRCNNRoiHead(Base3DRoIHead):
"""RoI head for PV-RCNN.
Args:
num_classes (int): The number of classes. Defaults to 3.
semantic_head (dict, optional): Config of semantic head.
Defaults to None.
bbox_roi_extractor (dict, optional): Config of roi_extractor.
Defaults to None.
bbox_head (dict, optional): Config of bbox_head. Defaults to None.
train_cfg (dict, optional): Train config of model.
Defaults to None.
test_cfg (dict, optional): Test config of model.
Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
num_classes: int = 3,
semantic_head: Optional[dict] = None,
bbox_roi_extractor: Optional[dict] = None,
bbox_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None):
super(PVRCNNRoiHead, self).__init__(
bbox_head=bbox_head,
bbox_roi_extractor=bbox_roi_extractor,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
self.num_classes = num_classes
self.semantic_head = MODELS.build(semantic_head)
self.init_assigner_sampler()
@property
def with_semantic(self):
"""bool: whether the head has semantic branch"""
return hasattr(self,
'semantic_head') and self.semantic_head is not None
def loss(self, feats_dict: dict, rpn_results_list: InstanceList,
batch_data_samples: SampleList, **kwargs) -> dict:
"""Training forward function of PVRCNNROIHead.
Args:
feats_dict (dict): Contains point-wise features.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
dict: losses from each head.
- loss_semantic (torch.Tensor): loss of semantic head.
- loss_bbox (torch.Tensor): loss of bboxes.
- loss_cls (torch.Tensor): loss of object classification.
- loss_corner (torch.Tensor): loss of bboxes corners.
"""
losses = dict()
batch_gt_instances_3d = []
batch_gt_instances_ignore = []
for data_sample in batch_data_samples:
batch_gt_instances_3d.append(data_sample.gt_instances_3d)
if 'ignored_instances' in data_sample:
batch_gt_instances_ignore.append(data_sample.ignored_instances)
else:
batch_gt_instances_ignore.append(None)
if self.with_semantic:
semantic_results = self._semantic_forward_train(
feats_dict['keypoint_features'], feats_dict['keypoints'],
batch_gt_instances_3d)
losses['loss_semantic'] = semantic_results['loss_semantic']
sample_results = self._assign_and_sample(rpn_results_list,
batch_gt_instances_3d)
if self.with_bbox:
bbox_results = self._bbox_forward_train(
semantic_results['seg_preds'],
feats_dict['fusion_keypoint_features'],
feats_dict['keypoints'], sample_results)
losses.update(bbox_results['loss_bbox'])
return losses
def predict(self, feats_dict: dict, rpn_results_list: InstanceList,
batch_data_samples: SampleList, **kwargs) -> SampleList:
"""Perform forward propagation of the roi head and predict detection
results on the features of the upstream network.
Args:
feats_dict (dict): Contains point-wise features.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
list[:obj:`InstanceData`]: Detection results of each sample
after the post process.
Each item usually contains the following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
contains a tensor with shape (num_instances, C), where
C >= 7.
"""
assert self.with_bbox, 'Bbox head must be implemented.'
assert self.with_semantic, 'Semantic head must be implemented.'
batch_input_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
semantic_results = self.semantic_head(feats_dict['keypoint_features'])
point_features = feats_dict[
'fusion_keypoint_features'] * semantic_results[
'seg_preds'].sigmoid().max(
dim=-1, keepdim=True).values
rois = bbox3d2roi(
[res['bboxes_3d'].tensor for res in rpn_results_list])
labels_3d = [res['labels_3d'] for res in rpn_results_list]
bbox_results = self._bbox_forward(point_features,
feats_dict['keypoints'], rois)
results_list = self.bbox_head.get_results(rois,
bbox_results['bbox_scores'],
bbox_results['bbox_reg'],
labels_3d, batch_input_metas,
self.test_cfg)
return results_list
def _bbox_forward_train(self, seg_preds: torch.Tensor,
keypoint_features: torch.Tensor,
keypoints: torch.Tensor,
sampling_results: SamplingResult) -> dict:
"""Forward training function of roi_extractor and bbox_head.
Args:
seg_preds (torch.Tensor): Point-wise segmentation predictions.
keypoint_features (torch.Tensor): Key point features
from the points encoder.
keypoints (torch.Tensor): Coordinates of key points.
sampling_results (:obj:`SamplingResult`): Sampled results used
for training.
Returns:
dict: Forward results including losses and predictions.
"""
rois = bbox3d2roi([res.bboxes for res in sampling_results])
keypoint_features = keypoint_features * seg_preds.sigmoid().max(
dim=-1, keepdim=True).values
bbox_results = self._bbox_forward(keypoint_features, keypoints, rois)
bbox_targets = self.bbox_head.get_targets(sampling_results,
self.train_cfg)
loss_bbox = self.bbox_head.loss(bbox_results['bbox_scores'],
bbox_results['bbox_reg'], rois,
*bbox_targets)
bbox_results.update(loss_bbox=loss_bbox)
return bbox_results
def _bbox_forward(self, keypoint_features: torch.Tensor,
keypoints: torch.Tensor, rois: torch.Tensor) -> dict:
"""Forward function of roi_extractor and bbox_head used in both
training and testing.
Args:
keypoint_features (torch.Tensor): Key point features
from the points encoder.
keypoints (torch.Tensor): Coordinates of key points.
rois (Tensor): RoI boxes.
Returns:
dict: Contains predictions of bbox_head and
features of roi_extractor.
"""
pooled_keypoint_features = self.bbox_roi_extractor(
keypoint_features, keypoints[..., 1:], keypoints[..., 0].int(),
rois)
bbox_score, bbox_reg = self.bbox_head(pooled_keypoint_features)
bbox_results = dict(bbox_scores=bbox_score, bbox_reg=bbox_reg)
return bbox_results
def _assign_and_sample(
self, proposal_list: InstanceList,
batch_gt_instances_3d: InstanceList) -> List[SamplingResult]:
"""Assign and sample proposals for training.
Args:
proposal_list (list[:obj:`InstanceData`]): Proposals produced by
rpn head.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
list[:obj:`SamplingResult`]: Sampled results of each training
sample.
"""
sampling_results = []
# bbox assign
for batch_idx in range(len(proposal_list)):
cur_proposal_list = proposal_list[batch_idx]
cur_boxes = cur_proposal_list['bboxes_3d']
cur_labels_3d = cur_proposal_list['labels_3d']
cur_gt_instances_3d = batch_gt_instances_3d[batch_idx]
cur_gt_instances_3d.bboxes_3d = cur_gt_instances_3d.\
bboxes_3d.tensor
cur_gt_bboxes = batch_gt_instances_3d[batch_idx].bboxes_3d.to(
cur_boxes.device)
cur_gt_labels = batch_gt_instances_3d[batch_idx].labels_3d
batch_num_gts = 0
# 0 is bg
batch_gt_indis = cur_gt_labels.new_full((len(cur_boxes), ), 0)
batch_max_overlaps = cur_boxes.tensor.new_zeros(len(cur_boxes))
# -1 is bg
batch_gt_labels = cur_gt_labels.new_full((len(cur_boxes), ), -1)
# each class may have its own assigner
if isinstance(self.bbox_assigner, list):
for i, assigner in enumerate(self.bbox_assigner):
gt_per_cls = (cur_gt_labels == i)
pred_per_cls = (cur_labels_3d == i)
cur_assign_res = assigner.assign(
cur_proposal_list[pred_per_cls],
cur_gt_instances_3d[gt_per_cls])
# gather assign_results in different class into one result
batch_num_gts += cur_assign_res.num_gts
# gt inds (1-based)
gt_inds_arange_pad = gt_per_cls.nonzero(
as_tuple=False).view(-1) + 1
# pad 0 for indice unassigned
gt_inds_arange_pad = F.pad(
gt_inds_arange_pad, (1, 0), mode='constant', value=0)
# pad -1 for indice ignore
gt_inds_arange_pad = F.pad(
gt_inds_arange_pad, (1, 0), mode='constant', value=-1)
# convert to 0~gt_num+2 for indices
gt_inds_arange_pad += 1
# now -1 is ignore, 0 is bg, >0 is fg in batch_gt_indis
batch_gt_indis[pred_per_cls] = gt_inds_arange_pad[
cur_assign_res.gt_inds + 1] - 1
batch_max_overlaps[
pred_per_cls] = cur_assign_res.max_overlaps
batch_gt_labels[pred_per_cls] = cur_assign_res.labels
assign_result = AssignResult(batch_num_gts, batch_gt_indis,
batch_max_overlaps,
batch_gt_labels)
else: # for single class
assign_result = self.bbox_assigner.assign(
cur_proposal_list, cur_gt_instances_3d)
# sample boxes
sampling_result = self.bbox_sampler.sample(assign_result,
cur_boxes.tensor,
cur_gt_bboxes,
cur_gt_labels)
sampling_results.append(sampling_result)
return sampling_results
def _semantic_forward_train(self, keypoint_features: torch.Tensor,
keypoints: torch.Tensor,
batch_gt_instances_3d: InstanceList) -> dict:
"""Train semantic head.
Args:
keypoint_features (torch.Tensor): Key point features
from the points encoder.
keypoints (torch.Tensor): Coordinates of key points.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
dict: Segmentation results including losses
"""
semantic_results = self.semantic_head(keypoint_features)
semantic_targets = self.semantic_head.get_targets(
keypoints, batch_gt_instances_3d)
loss_semantic = self.semantic_head.loss(semantic_results,
semantic_targets)
semantic_results.update(loss_semantic)
return semantic_results
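# --- Illustrative sketch (not part of the RoI head above) ---
# The RoI head stacks per-sample proposal boxes into one (N, 8) tensor whose
# first column is the batch index; this is the layout `bbox3d2roi` produces
# and the one `rois[..., 0]` / `rois[..., 1:]` unpack in the bbox head. The
# tensors below are toy values, not real proposals.
import torch

boxes_sample0 = torch.rand(3, 7)  # (x, y, z, dx, dy, dz, yaw) per box
boxes_sample1 = torch.rand(2, 7)
rois = torch.cat([
    torch.cat([torch.full((3, 1), 0.0), boxes_sample0], dim=1),
    torch.cat([torch.full((2, 1), 1.0), boxes_sample1], dim=1),
], dim=0)
print(rois.shape)    # torch.Size([5, 8])
print(rois[..., 0])  # tensor([0., 0., 0., 1., 1.])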
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.roi_heads.roi_extractors import SingleRoIExtractor
from .batch_roigridpoint_extractor import Batch3DRoIGridExtractor
from .single_roiaware_extractor import Single3DRoIAwareExtractor
from .single_roipoint_extractor import Single3DRoIPointExtractor
__all__ = [
'SingleRoIExtractor', 'Single3DRoIAwareExtractor',
'Single3DRoIPointExtractor', 'Batch3DRoIGridExtractor'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule
from mmdet3d.registry import MODELS
from mmdet3d.structures.bbox_3d import rotation_3d_in_axis
@MODELS.register_module()
class Batch3DRoIGridExtractor(BaseModule):
"""Grid point wise roi-aware Extractor.
Args:
grid_size (int): The number of grid points in a roi bbox.
Defaults to 6.
roi_layer (dict, optional): Config of the SA module used to extract
grid point features. Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
grid_size: int = 6,
roi_layer: dict = None,
init_cfg: dict = None) -> None:
super(Batch3DRoIGridExtractor, self).__init__(init_cfg=init_cfg)
self.roi_grid_pool_layer = MODELS.build(roi_layer)
self.grid_size = grid_size
def forward(self, feats: torch.Tensor, coordinate: torch.Tensor,
batch_inds: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
"""Forward roi extractor to extract grid points feature.
Args:
feats (torch.Tensor): Key points features.
coordinate (torch.Tensor): Key points coordinates.
batch_inds (torch.Tensor): Input batch indexes.
rois (torch.Tensor): Detection results of rpn head.
Returns:
torch.Tensor: Grid points features.
"""
batch_size = int(batch_inds.max()) + 1
xyz = coordinate
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for k in range(batch_size):
xyz_batch_cnt[k] = (batch_inds == k).sum()
rois_batch_inds = rois[:, 0].int()
# (N1+N2+..., 6x6x6, 3)
roi_grid = self.get_dense_grid_points(rois[:, 1:])
new_xyz = roi_grid.view(-1, 3)
new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int()
for k in range(batch_size):
new_xyz_batch_cnt[k] = ((rois_batch_inds == k).sum() *
roi_grid.size(1))
pooled_points, pooled_features = self.roi_grid_pool_layer(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz.contiguous(),
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=feats.contiguous()) # (M1 + M2 ..., C)
pooled_features = pooled_features.view(-1, self.grid_size,
self.grid_size, self.grid_size,
pooled_features.shape[-1])
# (BxN, 6, 6, 6, C)
return pooled_features
def get_dense_grid_points(self, rois: torch.Tensor) -> torch.Tensor:
"""Get dense grid points from rois.
Args:
rois (torch.Tensor): RoI boxes without batch index, in shape (N, 7).
Returns:
torch.Tensor: Grid points coordinates.
"""
rois_bbox = rois.clone()
rois_bbox[:, 2] += rois_bbox[:, 5] / 2
faked_features = rois_bbox.new_ones(
(self.grid_size, self.grid_size, self.grid_size))
dense_idx = faked_features.nonzero()
dense_idx = dense_idx.repeat(rois_bbox.size(0), 1, 1).float()
dense_idx = ((dense_idx + 0.5) / self.grid_size)
dense_idx[..., :3] -= 0.5
roi_ctr = rois_bbox[:, :3]
roi_dim = rois_bbox[:, 3:6]
roi_grid_points = dense_idx * roi_dim.view(-1, 1, 3)
roi_grid_points = rotation_3d_in_axis(
roi_grid_points, rois_bbox[:, 6], axis=2)
roi_grid_points += roi_ctr.view(-1, 1, 3)
return roi_grid_points
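# --- Illustrative sketch (not part of the extractor above) ---
# The grid-point construction in `get_dense_grid_points`, shown for a single
# axis-aligned RoI with grid_size=2: indices of a G x G x G cube are shifted
# to cell centres, normalized to [-0.5, 0.5], scaled by the box dimensions and
# translated to the box centre. The bottom-centre-to-centre z shift and the
# yaw rotation applied in the real method are omitted here for brevity.
import torch

grid_size = 2
roi = torch.tensor([10.0, 5.0, -1.0, 4.0, 2.0, 1.5, 0.0])  # x, y, z, dx, dy, dz, yaw
dense_idx = torch.ones(grid_size, grid_size, grid_size).nonzero().float()
dense_idx = (dense_idx + 0.5) / grid_size - 0.5  # cell centres in [-0.5, 0.5]
grid_points = dense_idx * roi[3:6] + roi[:3]     # scale by dims, move to centre
print(grid_points.shape)  # torch.Size([8, 3])
print(grid_points[0])     # tensor([ 9.0000,  4.5000, -1.3750])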
import unittest
import torch
from mmengine import DefaultScope
from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
_get_detector_cfg, _setup_seed)
class TestPVRCNN(unittest.TestCase):
def test_pvrcnn(self):
import mmdet3d.models
assert hasattr(mmdet3d.models, 'PointVoxelRCNN')
DefaultScope.get_instance('test_pvrcnn', scope_name='mmdet3d')
_setup_seed(0)
pvrcnn_cfg = _get_detector_cfg(
'pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py')
model = MODELS.build(pvrcnn_cfg)
num_gt_instance = 2
packed_inputs = _create_detector_inputs(
num_gt_instance=num_gt_instance)
# TODO: Support aug data test
# aug_packed_inputs = [
# _create_detector_inputs(num_gt_instance=num_gt_instance),
# _create_detector_inputs(num_gt_instance=num_gt_instance + 1)
# ]
# test_aug_test
# metainfo = {
# 'pcd_scale_factor': 1,
# 'pcd_horizontal_flip': 1,
# 'pcd_vertical_flip': 1,
# 'box_type_3d': LiDARInstance3DBoxes
# }
# for item in aug_packed_inputs:
# for batch_id in len(item['data_samples']):
# item['data_samples'][batch_id].set_metainfo(metainfo)
if torch.cuda.is_available():
model = model.cuda()
# test simple_test
with torch.no_grad():
data = model.data_preprocessor(packed_inputs, True)
torch.cuda.empty_cache()
results = model.forward(**data, mode='predict')
self.assertEqual(len(results), 1)
self.assertIn('bboxes_3d', results[0].pred_instances_3d)
self.assertIn('scores_3d', results[0].pred_instances_3d)
self.assertIn('labels_3d', results[0].pred_instances_3d)
# save the memory
with torch.no_grad():
losses = model.forward(**data, mode='loss')
torch.cuda.empty_cache()
self.assertGreater(losses['loss_rpn_cls'][0], 0)
self.assertGreaterEqual(losses['loss_rpn_bbox'][0], 0)
self.assertGreaterEqual(losses['loss_rpn_dir'][0], 0)
self.assertGreater(losses['loss_semantic'], 0)
self.assertGreaterEqual(losses['loss_bbox'], 0)
self.assertGreaterEqual(losses['loss_cls'], 0)
self.assertGreaterEqual(losses['loss_corner'], 0)
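# --- Usage note (assumed file location) ---
# The unit test above can be run on its own, e.g. with pytest; the path below
# is an assumption about where the test file lives in the repository:
#   pytest tests/test_models/test_detectors/test_pvrcnn.py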