Commit aa544d4a authored by VVsssssk, committed by ZwwWayne

[Features] Support PV_RCNN modules (#1957)

* add pvrcnn module code

* add voxelsa

* fix

* fix comments

* fix comments

* fix comments

* add stack sa

* fix

* fix comments

* fix comments

* fix

* add ut

* fix comments
parent 1fd71531
_base_ = [
'../_base_/datasets/kitti-3d-3class.py',
'../_base_/schedules/cyclic-40e.py', '../_base_/default_runtime.py'
]
voxel_size = [0.05, 0.05, 0.1]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
data_root = 'data/kitti/'
class_names = ['Pedestrian', 'Cyclist', 'Car']
metainfo = dict(CLASSES=class_names)
db_sampler = dict(
data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0,
prepare=dict(
filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
classes=class_names,
sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10),
points_loader=dict(
type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4))
train_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler, use_ground_plane=True),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05]),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
model = dict(
type='PointVoxelRCNN',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
voxel_layer=dict(
max_num_points=5, # max_points_per_voxel
point_cloud_range=point_cloud_range,
voxel_size=voxel_size,
max_voxels=(16000, 40000))),
voxel_encoder=dict(type='HardSimpleVFE'),
middle_encoder=dict(
type='SparseEncoder',
in_channels=4,
sparse_shape=[41, 1600, 1408],
order=('conv', 'norm', 'act'),
encoder_paddings=((0, 0, 0), ((1, 1, 1), 0, 0), ((1, 1, 1), 0, 0),
((0, 1, 1), 0, 0)),
return_middle_feats=True),
points_encoder=dict(
type='VoxelSetAbstraction',
num_keypoints=2048,
fused_out_channel=128,
voxel_size=voxel_size,
point_cloud_range=point_cloud_range,
voxel_sa_cfgs_list=[
dict(
type='StackedSAModuleMSG',
in_channels=16,
scale_factor=1,
radius=(0.4, 0.8),
sample_nums=(16, 16),
mlp_channels=((16, 16), (16, 16)),
use_xyz=True),
dict(
type='StackedSAModuleMSG',
in_channels=32,
scale_factor=2,
radius=(0.8, 1.2),
sample_nums=(16, 32),
mlp_channels=((32, 32), (32, 32)),
use_xyz=True),
dict(
type='StackedSAModuleMSG',
in_channels=64,
scale_factor=4,
radius=(1.2, 2.4),
sample_nums=(16, 32),
mlp_channels=((64, 64), (64, 64)),
use_xyz=True),
dict(
type='StackedSAModuleMSG',
in_channels=64,
scale_factor=8,
radius=(2.4, 4.8),
sample_nums=(16, 32),
mlp_channels=((64, 64), (64, 64)),
use_xyz=True)
],
rawpoints_sa_cfgs=dict(
type='StackedSAModuleMSG',
in_channels=1,
radius=(0.4, 0.8),
sample_nums=(16, 16),
mlp_channels=((16, 16), (16, 16)),
use_xyz=True),
bev_feat_channel=256,
bev_scale_factor=8),
backbone=dict(
type='SECOND',
in_channels=256,
layer_nums=[5, 5],
layer_strides=[1, 2],
out_channels=[128, 256]),
neck=dict(
type='SECONDFPN',
in_channels=[128, 256],
upsample_strides=[1, 2],
out_channels=[256, 256]),
rpn_head=dict(
type='PartA2RPNHead',
num_classes=3,
in_channels=512,
feat_channels=512,
use_direction_classifier=True,
dir_offset=0.78539,
anchor_generator=dict(
type='Anchor3DRangeGenerator',
ranges=[[0, -40.0, -0.6, 70.4, 40.0, -0.6],
[0, -40.0, -0.6, 70.4, 40.0, -0.6],
[0, -40.0, -1.78, 70.4, 40.0, -1.78]],
sizes=[[0.8, 0.6, 1.73], [1.76, 0.6, 1.73], [3.9, 1.6, 1.56]],
rotations=[0, 1.57],
reshape_out=False),
diff_rad_by_sin=True,
assigner_per_size=True,
assign_per_class=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict(
type='mmdet.FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(
type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict(
type='mmdet.CrossEntropyLoss', use_sigmoid=False,
loss_weight=0.2)),
roi_head=dict(
type='PVRCNNRoiHead',
num_classes=3,
semantic_head=dict(
type='ForegroundSegmentationHead',
in_channels=640,
extra_width=0.1,
loss_seg=dict(
type='mmdet.FocalLoss',
use_sigmoid=True,
reduction='sum',
gamma=2.0,
alpha=0.25,
activated=True,
loss_weight=1.0)),
bbox_roi_extractor=dict(
type='Batch3DRoIGridExtractor',
grid_size=6,
roi_layer=dict(
type='StackedSAModuleMSG',
in_channels=128,
radius=(0.8, 1.6),
sample_nums=(16, 16),
mlp_channels=((64, 64), (64, 64)),
use_xyz=True,
pool_mod='max'),
),
bbox_head=dict(
type='PVRCNNBBoxHead',
in_channels=128,
grid_size=6,
num_classes=3,
class_agnostic=True,
shared_fc_channels=(256, 256),
reg_channels=(256, 256),
cls_channels=(256, 256),
dropout_ratio=0.3,
with_corner_loss=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_bbox=dict(
type='mmdet.SmoothL1Loss',
beta=1.0 / 9.0,
reduction='sum',
loss_weight=1.0),
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=[
dict( # for Pedestrian
type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.5,
neg_iou_thr=0.35,
min_pos_iou=0.35,
ignore_iof_thr=-1),
dict( # for Cyclist
type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.5,
neg_iou_thr=0.35,
min_pos_iou=0.35,
ignore_iof_thr=-1),
dict( # for Car
type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.6,
neg_iou_thr=0.45,
min_pos_iou=0.45,
ignore_iof_thr=-1)
],
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=9000,
nms_post=512,
max_num=512,
nms_thr=0.8,
score_thr=0,
use_rotate_nms=True),
rcnn=dict(
assigner=[
dict( # for Pedestrian
type='Max3DIoUAssigner',
iou_calculator=dict(
type='BboxOverlaps3D', coordinate='lidar'),
pos_iou_thr=0.55,
neg_iou_thr=0.55,
min_pos_iou=0.55,
ignore_iof_thr=-1),
dict( # for Cyclist
type='Max3DIoUAssigner',
iou_calculator=dict(
type='BboxOverlaps3D', coordinate='lidar'),
pos_iou_thr=0.55,
neg_iou_thr=0.55,
min_pos_iou=0.55,
ignore_iof_thr=-1),
dict( # for Car
type='Max3DIoUAssigner',
iou_calculator=dict(
type='BboxOverlaps3D', coordinate='lidar'),
pos_iou_thr=0.55,
neg_iou_thr=0.55,
min_pos_iou=0.55,
ignore_iof_thr=-1)
],
sampler=dict(
type='IoUNegPiecewiseSampler',
num=128,
pos_fraction=0.5,
neg_piece_fractions=[0.8, 0.2],
neg_iou_piece_thrs=[0.55, 0.1],
neg_pos_ub=-1,
add_gt_as_proposals=False,
return_iou=True),
cls_pos_thr=0.75,
cls_neg_thr=0.25)),
test_cfg=dict(
rpn=dict(
nms_pre=1024,
nms_post=100,
max_num=100,
nms_thr=0.7,
score_thr=0,
use_rotate_nms=True),
rcnn=dict(
use_rotate_nms=True,
use_raw_score=True,
nms_thr=0.1,
score_thr=0.1)))
train_dataloader = dict(
batch_size=2,
num_workers=2,
dataset=dict(dataset=dict(pipeline=train_pipeline, metainfo=metainfo)))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
eval_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
lr = 0.001
optim_wrapper = dict(optimizer=dict(lr=lr))
param_scheduler = [
# learning rate scheduler
# During the first 15 epochs, the learning rate anneals from lr up to
# lr * 10; during the next 25 epochs, it anneals from lr * 10 down to
# lr * 1e-4.
dict(
type='CosineAnnealingLR',
T_max=15,
eta_min=lr * 10,
begin=0,
end=15,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=25,
eta_min=lr * 1e-4,
begin=15,
end=40,
by_epoch=True,
convert_to_iter_based=True),
# momentum scheduler
# During the first 15 epochs, the momentum anneals from its initial value
# down to 0.85 / 0.95; during the next 25 epochs, it anneals from
# 0.85 / 0.95 back up to 1.
dict(
type='CosineAnnealingMomentum',
T_max=15,
eta_min=0.85 / 0.95,
begin=0,
end=15,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingMomentum',
T_max=25,
eta_min=1,
begin=15,
end=40,
by_epoch=True,
convert_to_iter_based=True)
]
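The config above can be consumed directly by MMEngine. Below is a minimal sketch (not part of this commit), assuming the file is saved as configs/pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py inside an mmdet3d checkout that contains these changes:

from mmengine.config import Config

from mmdet3d.registry import MODELS
from mmdet3d.utils import register_all_modules

register_all_modules()  # register mmdet3d modules into the default scope
cfg = Config.fromfile('configs/pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py')
pvrcnn = MODELS.build(cfg.model)  # builds the PointVoxelRCNN detector defined above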
@@ -67,7 +67,6 @@ def init_model(config: Union[str, Path, Config],
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
# save the dataset_meta in the model for convenience
if 'dataset_meta' in checkpoint.get('meta', {}):
@@ -14,6 +14,7 @@ from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN
from .mvx_two_stage import MVXTwoStageDetector
from .parta2 import PartA2
from .point_rcnn import PointRCNN
from .pv_rcnn import PointVoxelRCNN
from .sassd import SASSD
from .single_stage_mono3d import SingleStageMono3DDetector
from .smoke_mono3d import SMOKEMono3D
@@ -26,5 +27,6 @@ __all__ = [
'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet',
'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector',
'FCOSMono3D', 'ImVoxelNet', 'GroupFree3DNet', 'PointRCNN', 'SMOKEMono3D',
'SASSD', 'MinkSingleStage3DDetector', 'MultiViewDfM', 'DfM',
'PointVoxelRCNN'
]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import InstanceList
from .two_stage import TwoStage3DDetector
@MODELS.register_module()
class PointVoxelRCNN(TwoStage3DDetector):
r"""PointVoxelRCNN detector.
Please refer to the `PV-RCNN <https://arxiv.org/abs/1912.13192>`_ paper for details.
Args:
voxel_encoder (dict): Point voxelization encoder layer.
middle_encoder (dict): Middle encoder layer
of the point cloud modality.
backbone (dict): Backbone for extracting point features.
neck (dict, optional): Neck for extracting point features.
Defaults to None.
rpn_head (dict, optional): Config of RPN head. Defaults to None.
points_encoder (dict, optional): Points encoder to extract point-wise
features. Defaults to None.
roi_head (dict, optional): Config of ROI head. Defaults to None.
train_cfg (dict, optional): Train config of model.
Defaults to None.
test_cfg (dict, optional): Test config of model.
Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`Det3DDataPreprocessor`. Defaults to None.
"""
def __init__(self,
voxel_encoder: dict,
middle_encoder: dict,
backbone: dict,
neck: Optional[dict] = None,
rpn_head: Optional[dict] = None,
points_encoder: Optional[dict] = None,
roi_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
data_preprocessor: Optional[dict] = None) -> None:
super().__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg,
data_preprocessor=data_preprocessor)
self.voxel_encoder = MODELS.build(voxel_encoder)
self.middle_encoder = MODELS.build(middle_encoder)
self.points_encoder = MODELS.build(points_encoder)
def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
input samples. Each Det3DDataSample usually contains
'pred_instances_3d', and the ``pred_instances_3d`` usually
contains the following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instance, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (Tensor): Contains a tensor with shape
(num_instances, C) where C >=7.
"""
feats_dict = self.extract_feat(batch_inputs_dict)
if self.with_rpn:
rpn_results_list = self.rpn_head.predict(feats_dict,
batch_data_samples)
else:
rpn_results_list = [
data_sample.proposals for data_sample in batch_data_samples
]
# extract point-wise features by points_encoder
points_feats_dict = self.extract_points_feat(batch_inputs_dict,
feats_dict,
rpn_results_list)
results_list_3d = self.roi_head.predict(points_feats_dict,
rpn_results_list,
batch_data_samples)
# convert to Det3DDataSample
results_list = self.add_pred_to_datasample(batch_data_samples,
results_list_3d)
return results_list
def extract_feat(self, batch_inputs_dict: dict) -> dict:
"""Extract features from the input voxels.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
Returns:
dict: A dict of features obtained from the backbone and
neck. It includes:
- spatial_feats (torch.Tensor): Spatial feats from middle
encoder.
- multi_scale_3d_feats (list[torch.Tensor]): Multi scale
middle feats from middle encoder.
- neck_feats (torch.Tensor): Neck feats from neck.
"""
feats_dict = dict()
voxel_dict = batch_inputs_dict['voxels']
voxel_features = self.voxel_encoder(voxel_dict['voxels'],
voxel_dict['num_points'],
voxel_dict['coors'])
batch_size = voxel_dict['coors'][-1, 0].item() + 1
feats_dict['spatial_feats'], feats_dict[
'multi_scale_3d_feats'] = self.middle_encoder(
voxel_features, voxel_dict['coors'], batch_size)
x = self.backbone(feats_dict['spatial_feats'])
if self.with_neck:
neck_feats = self.neck(x)
feats_dict['neck_feats'] = neck_feats
return feats_dict
def extract_points_feat(self, batch_inputs_dict: dict, feats_dict: dict,
rpn_results_list: InstanceList) -> dict:
"""Extract point-wise features from the raw points and voxel features.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
feats_dict (dict): Contains features from the first stage.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
Returns:
dict: Contains point-wise features, including:
- keypoints (torch.Tensor): Sampled key points.
- keypoint_features (torch.Tensor): Key point features gathered
from multiple inputs.
- fusion_keypoint_features (torch.Tensor): Key point features
fused by the point_feature_fusion_layer.
"""
return self.points_encoder(batch_inputs_dict, feats_dict,
rpn_results_list)
def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs):
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
dict: A dictionary of loss components.
"""
feats_dict = self.extract_feat(batch_inputs_dict)
losses = dict()
# RPN forward and loss
if self.with_rpn:
proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn)
rpn_data_samples = copy.deepcopy(batch_data_samples)
rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(
feats_dict,
rpn_data_samples,
proposal_cfg=proposal_cfg,
**kwargs)
# avoid getting the same names as the roi_head losses
keys = rpn_losses.keys()
for key in keys:
if 'loss' in key and 'rpn' not in key:
rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
losses.update(rpn_losses)
else:
# TODO: Not supported currently; should add a check as in Fast R-CNN
assert batch_data_samples[0].get('proposals', None) is not None
# use pre-defined proposals in InstanceData for the second stage
# to extract ROI features.
rpn_results_list = [
data_sample.proposals for data_sample in batch_data_samples
]
points_feats_dict = self.extract_points_feat(batch_inputs_dict,
feats_dict,
rpn_results_list)
roi_losses = self.roi_head.loss(points_feats_dict, rpn_results_list,
batch_data_samples)
losses.update(roi_losses)
return losses
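For reference, a rough sketch (illustrative sizes only, not produced by this commit) of the batch_inputs_dict layout that the docstrings above describe, as it comes out of Det3DDataPreprocessor with the voxel layer from the config:

import torch

batch_inputs_dict = dict(
    # one KITTI sample with (x, y, z, intensity) per point
    points=[torch.rand(20000, 4)],
    voxels=dict(
        voxels=torch.rand(16000, 5, 4),  # (num_voxels, max_points_per_voxel, 4)
        num_points=torch.randint(1, 6, (16000, )),  # valid points per voxel
        coors=torch.zeros(16000, 4, dtype=torch.int32)  # (batch_idx, z, y, x)
    ))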
@@ -4,9 +4,10 @@ from .paconv_sa_module import (PAConvCUDASAModule, PAConvCUDASAModuleMSG,
PAConvSAModule, PAConvSAModuleMSG)
from .point_fp_module import PointFPModule
from .point_sa_module import PointSAModule, PointSAModuleMSG
from .stack_point_sa_module import StackedSAModuleMSG
__all__ = [
'build_sa_module', 'PointSAModuleMSG', 'PointSAModule', 'PointFPModule',
'PAConvSAModule', 'PAConvSAModuleMSG', 'PAConvCUDASAModule',
'PAConvCUDASAModuleMSG', 'StackedSAModuleMSG'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import ball_query, grouping_operation
from mmengine.model import BaseModule
from torch import Tensor
from mmdet3d.registry import MODELS
class StackQueryAndGroup(BaseModule):
"""Find nearby points in spherical space.
Args:
radius (float): Radius of the ball query.
sample_nums (int): Number of points sampled in the ball query.
use_xyz (bool): Whether to use xyz. Default: True.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
radius: float,
sample_nums: int,
use_xyz: bool = True,
init_cfg: dict = None):
super().__init__(init_cfg=init_cfg)
self.radius, self.sample_nums, self.use_xyz = \
radius, sample_nums, use_xyz
def forward(self,
xyz: torch.Tensor,
xyz_batch_cnt: torch.Tensor,
new_xyz: torch.Tensor,
new_xyz_batch_cnt: torch.Tensor,
features: torch.Tensor = None) -> Tuple[Tensor, Tensor]:
"""Forward.
Args:
xyz (Tensor): XYZ coordinates of the stacked input points with
shape (N1 + N2 ..., 3).
xyz_batch_cnt (Tensor): Number of input points in each batch,
e.g. (N1, N2, ...).
new_xyz (Tensor): Coordinates of the output (query) points with
shape (M1 + M2 ..., 3).
new_xyz_batch_cnt (Tensor): Number of query points in each
batch, e.g. (M1, M2, ...).
features (Tensor, optional): Features of each input point with
shape (N1 + N2 ..., C), where C is the feature channel number. Default: None.
"""
assert xyz.shape[0] == xyz_batch_cnt.sum(), \
f'xyz: {xyz.shape}, xyz_batch_cnt: {xyz_batch_cnt}'
assert new_xyz.shape[0] == new_xyz_batch_cnt.sum(), \
f'new_xyz: {new_xyz.shape}, ' \
f'new_xyz_batch_cnt: {new_xyz_batch_cnt}'
# idx: (M1 + M2 ..., nsample), empty_ball_mask: (M1 + M2 ...)
idx, empty_ball_mask = ball_query(0, self.radius, self.sample_nums,
xyz, new_xyz, xyz_batch_cnt,
new_xyz_batch_cnt)
grouped_xyz = grouping_operation(
xyz, idx, xyz_batch_cnt,
new_xyz_batch_cnt) # (M1 + M2, 3, nsample)
grouped_xyz -= new_xyz.unsqueeze(-1)
grouped_xyz[empty_ball_mask] = 0
if features is not None:
grouped_features = grouping_operation(
features, idx, xyz_batch_cnt,
new_xyz_batch_cnt) # (M1 + M2, C, nsample)
grouped_features[empty_ball_mask] = 0
if self.use_xyz:
new_features = torch.cat(
[grouped_xyz, grouped_features],
dim=1) # (M1 + M2 ..., C + 3, nsample)
else:
new_features = grouped_features
else:
assert self.use_xyz, \
'features cannot be None when use_xyz is False!'
new_features = grouped_xyz
return new_features, idx
@MODELS.register_module()
class StackedSAModuleMSG(BaseModule):
"""Stack point set abstraction module.
Args:
in_channels (int): Input channels.
radius (list[float]): List of ball query radii, one per scale.
sample_nums (list[int]): Number of points sampled in each ball
query.
mlp_channels (list[list[int]]): Specify mlp channels of the
pointnet before the global pooling for each scale to encode
point features.
use_xyz (bool): Whether to use xyz. Default: True.
pool_mod (str): Type of pooling method, 'max' or 'avg'.
Default: 'max'.
norm_cfg (dict): Type of normalization method. Defaults to
dict(type='BN2d', eps=1e-5, momentum=0.01).
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
in_channels: int,
radius: List[float],
sample_nums: List[int],
mlp_channels: List[List[int]],
use_xyz: bool = True,
pool_mod='max',
norm_cfg: dict = dict(type='BN2d', eps=1e-5, momentum=0.01),
init_cfg: dict = None,
**kwargs) -> None:
super(StackedSAModuleMSG, self).__init__(init_cfg=init_cfg)
assert len(radius) == len(sample_nums) == len(mlp_channels)
self.groupers = nn.ModuleList()
self.mlps = nn.ModuleList()
for i in range(len(radius)):
cin = in_channels
if use_xyz:
cin += 3
cur_radius = radius[i]
nsample = sample_nums[i]
mlp_spec = mlp_channels[i]
self.groupers.append(
StackQueryAndGroup(cur_radius, nsample, use_xyz=use_xyz))
mlp = nn.Sequential()
for j in range(len(mlp_spec)):
cout = mlp_spec[j]
mlp.add_module(
f'layer{j}',
ConvModule(
cin,
cout,
kernel_size=(1, 1),
stride=(1, 1),
conv_cfg=dict(type='Conv2d'),
norm_cfg=norm_cfg,
bias=False))
cin = cout
self.mlps.append(mlp)
self.pool_mod = pool_mod
def forward(self,
xyz: Tensor,
xyz_batch_cnt: Tensor,
new_xyz: Tensor,
new_xyz_batch_cnt: Tensor,
features: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Forward.
Args:
xyz (Tensor): XYZ coordinates of the stacked input points with
shape (N1 + N2 ..., 3).
xyz_batch_cnt (Tensor): Number of input points in each batch,
e.g. (N1, N2, ...).
new_xyz (Tensor): Coordinates of the output (query) points with
shape (M1 + M2 ..., 3).
new_xyz_batch_cnt (Tensor): Number of query points in each
batch, e.g. (M1, M2, ...).
features (Tensor, optional): Features of each input point with
shape (N1 + N2 ..., C), where C is the feature channel number. Default: None.
Returns:
Tuple[Tensor, Tensor]: New point coordinates and features:
- new_xyz (Tensor): Target point coordinates with shape
(M1 + M2 ..., 3).
- new_features (Tensor): Target point features with shape
(M1 + M2 ..., sum_k(mlp_channels[k][-1])).
"""
new_features_list = []
for k in range(len(self.groupers)):
grouped_features, ball_idxs = self.groupers[k](
xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt,
features) # (M1 + M2, Cin, nsample)
grouped_features = grouped_features.permute(1, 0,
2).unsqueeze(dim=0)
new_features = self.mlps[k](grouped_features)
# (1, Cout, M1 + M2 ..., nsample)
if self.pool_mod == 'max':
new_features = new_features.max(-1).values
elif self.pool_mod == 'avg':
new_features = new_features.mean(-1)
else:
raise NotImplementedError
new_features = new_features.squeeze(dim=0).permute(1, 0)
new_features_list.append(new_features)
new_features = torch.cat(new_features_list, dim=1)
return new_xyz, new_features
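A minimal usage sketch of the stacked-batch convention used above (a sketch only; it assumes an mmdet3d install containing this commit and a CUDA device, since the underlying mmcv ball_query/grouping_operation kernels run on GPU):

import torch

from mmdet3d.registry import MODELS
from mmdet3d.utils import register_all_modules

register_all_modules()
sa = MODELS.build(
    dict(
        type='StackedSAModuleMSG',
        in_channels=16,
        radius=(0.4, 0.8),
        sample_nums=(16, 16),
        mlp_channels=((16, 16), (16, 16)),
        use_xyz=True)).cuda()
# Two samples with 100 and 80 points are stacked along dim 0.
xyz = torch.rand(180, 3).cuda()
xyz_batch_cnt = torch.tensor([100, 80], dtype=torch.int32).cuda()
new_xyz = torch.rand(64, 3).cuda()  # 32 query (key) points per sample
new_xyz_batch_cnt = torch.tensor([32, 32], dtype=torch.int32).cuda()
features = torch.rand(180, 16).cuda()
_, new_feats = sa(xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features)
# new_feats has shape (64, 32): both scales output 16 channels each.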
@@ -2,7 +2,9 @@
from .pillar_scatter import PointPillarsScatter
from .sparse_encoder import SparseEncoder, SparseEncoderSASSD
from .sparse_unet import SparseUNet
from .voxel_set_abstraction import VoxelSetAbstraction
__all__ = [
'PointPillarsScatter', 'SparseEncoder', 'SparseEncoderSASSD', 'SparseUNet',
'VoxelSetAbstraction'
]
@@ -41,6 +41,8 @@ class SparseEncoder(nn.Module):
Defaults to ((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)).
block_type (str, optional): Type of the block to use.
Defaults to 'conv_module'.
return_middle_feats (bool): Whether to output the middle (per-stage)
features. Defaults to False.
"""
def __init__(self,
@@ -54,7 +56,8 @@
64)),
encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1)),
block_type='conv_module'):
block_type='conv_module',
return_middle_feats=False):
super().__init__()
assert block_type in ['conv_module', 'basicblock']
self.sparse_shape = sparse_shape
@@ -66,6 +69,7 @@
self.encoder_paddings = encoder_paddings
self.stage_num = len(self.encoder_channels)
self.fp16_enabled = False
self.return_middle_feats = return_middle_feats
# Spconv init all weight on its own
assert isinstance(order, tuple) and len(order) == 3
@@ -117,7 +121,14 @@
batch_size (int): Batch size.
Returns:
torch.Tensor | tuple[torch.Tensor, list]: Spatial features and,
optionally, the middle-layer features:
- spatial_features (torch.Tensor): Spatial features output by
the last layer.
- encode_features (list[SparseConvTensor], optional): Middle-layer
output features, returned only when self.return_middle_feats
is True.
"""
coors = coors.int()
input_sp_tensor = SparseConvTensor(voxel_features, coors,
@@ -137,7 +148,10 @@
N, C, D, H, W = spatial_features.shape
spatial_features = spatial_features.view(N, C * D, H, W)
if self.return_middle_feats:
return spatial_features, encode_features
else:
return spatial_features
def make_encoder_layers(self,
make_block,
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import mmengine
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops.furthest_point_sample import furthest_point_sample
from mmengine.model import BaseModule
from mmdet3d.registry import MODELS
from mmdet3d.utils import InstanceList
def bilinear_interpolate_torch(inputs, x, y):
"""Bilinear interpolate for inputs."""
x0 = torch.floor(x).long()
x1 = x0 + 1
y0 = torch.floor(y).long()
y1 = y0 + 1
x0 = torch.clamp(x0, 0, inputs.shape[1] - 1)
x1 = torch.clamp(x1, 0, inputs.shape[1] - 1)
y0 = torch.clamp(y0, 0, inputs.shape[0] - 1)
y1 = torch.clamp(y1, 0, inputs.shape[0] - 1)
Ia = inputs[y0, x0]
Ib = inputs[y1, x0]
Ic = inputs[y0, x1]
Id = inputs[y1, x1]
wa = (x1.type_as(x) - x) * (y1.type_as(y) - y)
wb = (x1.type_as(x) - x) * (y - y0.type_as(y))
wc = (x - x0.type_as(x)) * (y1.type_as(y) - y)
wd = (x - x0.type_as(x)) * (y - y0.type_as(y))
ans = torch.t((torch.t(Ia) * wa)) + torch.t(torch.t(Ib) * wb) + torch.t(
torch.t(Ic) * wc) + torch.t(torch.t(Id) * wd)
return ans
@MODELS.register_module()
class VoxelSetAbstraction(BaseModule):
"""Voxel set abstraction module for PVRCNN and PVRCNN++.
Args:
num_keypoints (int): The number of key points sampled from
raw points cloud.
fused_out_channel (int): Number of output channels of the fused
key point features. Defaults to 128.
voxel_size (list[float]): Size of voxels. Defaults to
[0.05, 0.05, 0.1].
point_cloud_range (list[float]): Point cloud range. Defaults to
[0, -40, -3, 70.4, 40, 1].
voxel_sa_cfgs_list (List[dict or ConfigDict], optional): List of SA
module cfgs used to gather key point features from multi-scale
voxel features. Defaults to None.
rawpoints_sa_cfgs (dict or ConfigDict, optional): SA module cfg
used to gather key point features from raw points. Defaults to
None.
bev_feat_channel (int): Number of BEV feature channels.
Defaults to 256.
bev_scale_factor (int): BEV feature scale factor. Defaults to 8.
voxel_center_as_source (bool): Whether to use voxel centers as
key points. Defaults to False.
norm_cfg (dict[str]): Config of normalization layer. Defaults to
dict(type='BN2d', eps=1e-5, momentum=0.1).
bias (bool | str, optional): If specified as `auto`, it will be
decided by `norm_cfg`. `bias` will be set as True if
`norm_cfg` is None, otherwise False. Default: 'auto'.
"""
def __init__(self,
num_keypoints: int,
fused_out_channel: int = 128,
voxel_size: list = [0.05, 0.05, 0.1],
point_cloud_range: list = [0, -40, -3, 70.4, 40, 1],
voxel_sa_cfgs_list: Optional[list] = None,
rawpoints_sa_cfgs: Optional[dict] = None,
bev_feat_channel: int = 256,
bev_scale_factor: int = 8,
voxel_center_as_source: bool = False,
norm_cfg: dict = dict(type='BN2d', eps=1e-5, momentum=0.1),
bias: str = 'auto') -> None:
super().__init__()
self.num_keypoints = num_keypoints
self.fused_out_channel = fused_out_channel
self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range
self.voxel_center_as_source = voxel_center_as_source
gathered_channel = 0
if rawpoints_sa_cfgs is not None:
self.rawpoints_sa_layer = MODELS.build(rawpoints_sa_cfgs)
gathered_channel += sum(
[x[-1] for x in rawpoints_sa_cfgs.mlp_channels])
else:
self.rawpoints_sa_layer = None
if voxel_sa_cfgs_list is not None:
self.voxel_sa_configs_list = voxel_sa_cfgs_list
self.voxel_sa_layers = nn.ModuleList()
for voxel_sa_config in voxel_sa_cfgs_list:
cur_layer = MODELS.build(voxel_sa_config)
self.voxel_sa_layers.append(cur_layer)
gathered_channel += sum(
[x[-1] for x in voxel_sa_config.mlp_channels])
else:
self.voxel_sa_layers = None
if bev_feat_channel is not None and bev_scale_factor is not None:
self.bev_cfg = mmengine.Config(
dict(
bev_feat_channels=bev_feat_channel,
bev_scale_factor=bev_scale_factor))
gathered_channel += bev_feat_channel
else:
self.bev_cfg = None
self.point_feature_fusion_layer = nn.Sequential(
ConvModule(
gathered_channel,
fused_out_channel,
kernel_size=(1, 1),
stride=(1, 1),
conv_cfg=dict(type='Conv2d'),
norm_cfg=norm_cfg,
bias=bias))
def interpolate_from_bev_features(self, keypoints: torch.Tensor,
bev_features: torch.Tensor,
batch_size: int,
bev_scale_factor: int) -> torch.Tensor:
"""Gather key points features from bev feature map by interpolate.
Args:
keypoints (torch.Tensor): Sampled key points with shape
(N1 + N2 + ..., NDim).
bev_features (torch.Tensor): Bev feature map from the first
stage with shape (B, C, H, W).
batch_size (int): Input batch size.
bev_scale_factor (int): Bev feature map scale factor.
Returns:
torch.Tensor: Key points features gather from bev feature
map with shape (N1 + N2 + ..., C)
"""
x_idxs = (keypoints[..., 0] -
self.point_cloud_range[0]) / self.voxel_size[0]
y_idxs = (keypoints[..., 1] -
self.point_cloud_range[1]) / self.voxel_size[1]
x_idxs = x_idxs / bev_scale_factor
y_idxs = y_idxs / bev_scale_factor
point_bev_features_list = []
for k in range(batch_size):
cur_x_idxs = x_idxs[k, ...]
cur_y_idxs = y_idxs[k, ...]
cur_bev_features = bev_features[k].permute(1, 2, 0) # (H, W, C)
point_bev_features = bilinear_interpolate_torch(
cur_bev_features, cur_x_idxs, cur_y_idxs)
point_bev_features_list.append(point_bev_features)
point_bev_features = torch.cat(
point_bev_features_list, dim=0) # (N1 + N2 + ..., C)
return point_bev_features.view(batch_size, keypoints.shape[1], -1)
def get_voxel_centers(self, coors: torch.Tensor,
scale_factor: float) -> torch.Tensor:
"""Get voxel centers coordinate.
Args:
coors (torch.Tensor): Coordinates of voxels shape is Nx(1+NDim),
where 1 represents the batch index.
scale_factor (float): Scale factor.
Returns:
torch.Tensor: Voxel centers coordinate with shape (N, 3).
"""
assert coors.shape[1] == 4
voxel_centers = coors[:, [3, 2, 1]].float() # (xyz)
voxel_size = torch.tensor(
self.voxel_size,
device=voxel_centers.device).float() * scale_factor
pc_range = torch.tensor(
self.point_cloud_range[0:3], device=voxel_centers.device).float()
voxel_centers = (voxel_centers + 0.5) * voxel_size + pc_range
return voxel_centers
def sample_key_points(self, points: List[torch.Tensor],
coors: torch.Tensor) -> torch.Tensor:
"""Sample key points from raw points cloud.
Args:
points (List[torch.Tensor]): Point cloud of each sample.
coors (torch.Tensor): Coordinates of voxels shape is Nx(1+NDim),
where 1 represents the batch index.
Returns:
torch.Tensor: (B, M, 3) Key points of each sample.
M is num_keypoints.
"""
assert points is not None or coors is not None
if self.voxel_center_as_source:
_src_points = self.get_voxel_centers(coors=coors, scale_factor=1)
batch_size = coors[-1, 0].item() + 1
src_points = [
_src_points[coors[:, 0] == b] for b in range(batch_size)
]
else:
src_points = [p[..., :3] for p in points]
keypoints_list = []
for points_to_sample in src_points:
num_points = points_to_sample.shape[0]
cur_pt_idxs = furthest_point_sample(
points_to_sample.unsqueeze(dim=0).contiguous(),
self.num_keypoints).long()[0]
if num_points < self.num_keypoints:
times = int(self.num_keypoints / num_points) + 1
non_empty = cur_pt_idxs[:num_points]
cur_pt_idxs = non_empty.repeat(times)[:self.num_keypoints]
keypoints = points_to_sample[cur_pt_idxs]
keypoints_list.append(keypoints)
keypoints = torch.stack(keypoints_list, dim=0) # (B, M, 3)
return keypoints
def forward(self, batch_inputs_dict: dict, feats_dict: dict,
rpn_results_list: InstanceList) -> dict:
"""Extract point-wise features from multi-input.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
feats_dict (dict): Contains features from the first
stage.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
Returns:
dict: Contains point-wise features, including:
- keypoints (torch.Tensor): Sampled key points.
- keypoint_features (torch.Tensor): Key point features gathered
from multiple inputs.
- fusion_keypoint_features (torch.Tensor): Key point features
fused by the point_feature_fusion_layer.
"""
points = batch_inputs_dict['points']
voxel_encode_features = feats_dict['multi_scale_3d_feats']
bev_encode_features = feats_dict['spatial_feats']
if self.voxel_center_as_source:
voxels_coors = batch_inputs_dict['voxels']['coors']
else:
voxels_coors = None
keypoints = self.sample_key_points(points, voxels_coors)
point_features_list = []
batch_size = len(points)
if self.bev_cfg is not None:
point_bev_features = self.interpolate_from_bev_features(
keypoints, bev_encode_features, batch_size,
self.bev_cfg.bev_scale_factor)
point_features_list.append(point_bev_features.contiguous())
batch_size, num_keypoints, _ = keypoints.shape
key_xyz = keypoints.view(-1, 3)
key_xyz_batch_cnt = key_xyz.new_zeros(batch_size).int().fill_(
num_keypoints)
if self.rawpoints_sa_layer is not None:
batch_points = torch.cat(points, dim=0)
batch_cnt = [len(p) for p in points]
xyz = batch_points[:, :3].contiguous()
features = None
if batch_points.size(1) > 0:
features = batch_points[:, 3:].contiguous()
xyz_batch_cnt = xyz.new_tensor(batch_cnt, dtype=torch.int32)
pooled_points, pooled_features = self.rawpoints_sa_layer(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=key_xyz.contiguous(),
new_xyz_batch_cnt=key_xyz_batch_cnt,
features=features.contiguous(),
)
point_features_list.append(pooled_features.contiguous().view(
batch_size, num_keypoints, -1))
if self.voxel_sa_layers is not None:
for k, voxel_sa_layer in enumerate(self.voxel_sa_layers):
cur_coords = voxel_encode_features[k].indices
xyz = self.get_voxel_centers(
coors=cur_coords,
scale_factor=self.voxel_sa_configs_list[k].scale_factor
).contiguous()
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (cur_coords[:, 0] == bs_idx).sum()
pooled_points, pooled_features = voxel_sa_layer(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=key_xyz.contiguous(),
new_xyz_batch_cnt=key_xyz_batch_cnt,
features=voxel_encode_features[k].features.contiguous(),
)
point_features_list.append(pooled_features.contiguous().view(
batch_size, num_keypoints, -1))
point_features = torch.cat(
point_features_list, dim=-1).view(batch_size * num_keypoints, -1,
1)
fusion_point_features = self.point_feature_fusion_layer(
point_features.unsqueeze(dim=-1)).squeeze(dim=-1)
batch_idxs = torch.arange(
batch_size * num_keypoints, device=keypoints.device
) // num_keypoints # batch indexes of each key points
batch_keypoints_xyz = torch.cat(
(batch_idxs.to(key_xyz.dtype).unsqueeze(dim=-1), key_xyz), dim=-1)
return dict(
keypoint_features=point_features.squeeze(dim=-1),
fusion_keypoint_features=fusion_point_features.squeeze(dim=-1),
keypoints=batch_keypoints_xyz)
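A quick sanity check of bilinear_interpolate_torch defined above (a sketch; the import path assumes this file lands in mmdet3d/models/middle_encoders/, as the middle_encoders __init__ diff above indicates):

import torch

from mmdet3d.models.middle_encoders.voxel_set_abstraction import \
    bilinear_interpolate_torch

bev = torch.tensor([[[0.], [1.]],
                    [[2.], [3.]]])  # (H, W, C) = (2, 2, 1)
x = torch.tensor([0.5])
y = torch.tensor([0.5])
# Sampling the centre of the 2x2 map returns the mean of its four values.
print(bilinear_interpolate_torch(bev, x, y))  # tensor([[1.5000]])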
@@ -5,10 +5,11 @@ from .h3d_roi_head import H3DRoIHead
from .mask_heads import PointwiseSemanticHead, PrimitiveHead
from .part_aggregation_roi_head import PartAggregationROIHead
from .point_rcnn_roi_head import PointRCNNRoIHead
from .pv_rcnn_roi_head import PVRCNNRoiHead
from .roi_extractors import Single3DRoIAwareExtractor, SingleRoIExtractor
__all__ = [
'Base3DRoIHead', 'PartAggregationROIHead', 'PointwiseSemanticHead',
'Single3DRoIAwareExtractor', 'PartA2BboxHead', 'SingleRoIExtractor',
'H3DRoIHead', 'PrimitiveHead', 'PointRCNNRoIHead', 'PVRCNNRoiHead'
]
@@ -7,9 +7,10 @@ from mmdet.models.roi_heads.bbox_heads import (BBoxHead, ConvFCBBoxHead,
from .h3d_bbox_head import H3DBboxHead
from .parta2_bbox_head import PartA2BboxHead
from .point_rcnn_bbox_head import PointRCNNBboxHead
from .pv_rcnn_bbox_head import PVRCNNBBoxHead
__all__ = [
'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead',
'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'PartA2BboxHead',
'H3DBboxHead', 'PointRCNNBboxHead', 'PVRCNNBBoxHead'
]
# Copyright (c) OpenMMLab. All rights reserved.
from .foreground_segmentation_head import ForegroundSegmentationHead
from .pointwise_semantic_head import PointwiseSemanticHead
from .primitive_head import PrimitiveHead
__all__ = [
'PointwiseSemanticHead', 'PrimitiveHead', 'ForegroundSegmentationHead'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple
import torch
from mmcv.cnn.bricks import build_norm_layer
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import nn as nn
from mmdet3d.models.builder import build_loss
from mmdet3d.registry import MODELS
from mmdet3d.utils import InstanceList
from mmdet.models.utils import multi_apply
@MODELS.register_module()
class ForegroundSegmentationHead(BaseModule):
"""Foreground segmentation head.
Args:
in_channels (int): The number of input channel.
mlp_channels (tuple[int]): Specification of the mlp channels. Defaults
to (256, 256).
extra_width (float): Width by which ground-truth boxes are enlarged.
Defaults to 0.1.
norm_cfg (dict): Type of normalization method. Defaults to
dict(type='BN1d', eps=1e-5, momentum=0.1).
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
loss_seg (dict): Config of segmentation loss. Defaults to
dict(type='mmdet.FocalLoss')
"""
def __init__(
self,
in_channels: int,
mlp_channels: Tuple[int] = (256, 256),
extra_width: float = 0.1,
norm_cfg: dict = dict(type='BN1d', eps=1e-5, momentum=0.1),
init_cfg: Optional[dict] = None,
loss_seg: dict = dict(
type='mmdet.FocalLoss',
use_sigmoid=True,
reduction='sum',
gamma=2.0,
alpha=0.25,
activated=True,
loss_weight=1.0)
) -> None:
super(ForegroundSegmentationHead, self).__init__(init_cfg=init_cfg)
self.extra_width = extra_width
self.num_classes = 1
self.in_channels = in_channels
self.use_sigmoid_cls = loss_seg.get('use_sigmoid', False)
out_channels = 1
if self.use_sigmoid_cls:
self.out_channels = out_channels
else:
self.out_channels = out_channels + 1
mlps_layers = []
cin = in_channels
for mlp in mlp_channels:
mlps_layers.extend([
nn.Linear(cin, mlp, bias=False),
build_norm_layer(norm_cfg, mlp)[1],
nn.ReLU()
])
cin = mlp
mlps_layers.append(nn.Linear(cin, self.out_channels, bias=True))
self.seg_cls_layer = nn.Sequential(*mlps_layers)
self.loss_seg = build_loss(loss_seg)
def forward(self, feats: torch.Tensor) -> dict:
"""Forward head.
Args:
feats (torch.Tensor): Point-wise features.
Returns:
dict: Segment predictions.
"""
seg_preds = self.seg_cls_layer(feats)
return dict(seg_preds=seg_preds)
def _get_targets_single(self, point_xyz: torch.Tensor,
gt_bboxes_3d: InstanceData,
gt_labels_3d: torch.Tensor) -> torch.Tensor:
"""generate segmentation targets for a single sample.
Args:
point_xyz (torch.Tensor): Coordinate of points.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth boxes in
shape (box_num, 7).
gt_labels_3d (torch.Tensor): Class labels of ground truths in
shape (box_num).
Returns:
torch.Tensor: Points class labels.
"""
point_cls_labels_single = point_xyz.new_zeros(
point_xyz.shape[0]).long()
enlarged_gt_boxes = gt_bboxes_3d.enlarged_box(self.extra_width)
box_idxs_of_pts = gt_bboxes_3d.points_in_boxes_part(point_xyz).long()
extend_box_idxs_of_pts = enlarged_gt_boxes.points_in_boxes_part(
point_xyz).long()
box_fg_flag = box_idxs_of_pts >= 0
fg_flag = box_fg_flag.clone()
ignore_flag = fg_flag ^ (extend_box_idxs_of_pts >= 0)
point_cls_labels_single[ignore_flag] = -1
gt_box_of_fg_points = gt_labels_3d[box_idxs_of_pts[fg_flag]]
point_cls_labels_single[
fg_flag] = 1 if self.num_classes == 1 else\
gt_box_of_fg_points.long()
return point_cls_labels_single,
def get_targets(self, points_bxyz: torch.Tensor,
batch_gt_instances_3d: InstanceList) -> dict:
"""Generate segmentation targets.
Args:
points_bxyz (torch.Tensor): Stacked point coordinates in shape
(num_points, 1 + 3), where the first column is the batch index.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
dict: Prediction targets
- seg_targets (torch.Tensor): Segmentation targets.
"""
batch_size = len(batch_gt_instances_3d)
points_xyz_list = []
gt_bboxes_3d = []
gt_labels_3d = []
for idx in range(batch_size):
coords_idx = points_bxyz[:, 0] == idx
points_xyz_list.append(points_bxyz[coords_idx][..., 1:])
gt_bboxes_3d.append(batch_gt_instances_3d[idx].bboxes_3d)
gt_labels_3d.append(batch_gt_instances_3d[idx].labels_3d)
seg_targets, = multi_apply(self._get_targets_single, points_xyz_list,
gt_bboxes_3d, gt_labels_3d)
seg_targets = torch.cat(seg_targets, dim=0)
return dict(seg_targets=seg_targets)
def loss(self, semantic_results: dict,
semantic_targets: dict) -> Dict[str, torch.Tensor]:
"""Calculate point-wise segmentation losses.
Args:
semantic_results (dict): Results from semantic head.
semantic_targets (dict): Targets of semantic results.
Returns:
dict: Loss of segmentation.
- loss_semantic (torch.Tensor): Segmentation prediction loss.
"""
seg_preds = semantic_results['seg_preds']
seg_targets = semantic_targets['seg_targets']
positives = (seg_targets > 0)
negative_cls_weights = (seg_targets == 0).float()
seg_weights = (negative_cls_weights + 1.0 * positives).float()
pos_normalizer = positives.sum(dim=0).float()
seg_weights /= torch.clamp(pos_normalizer, min=1.0)
seg_preds = torch.sigmoid(seg_preds)
loss_seg = self.loss_seg(seg_preds, (~positives).long(), seg_weights)
return dict(loss_semantic=loss_seg)
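A worked toy example of the weighting scheme inside loss() above; this standalone sketch simply mirrors those lines to show that ignored points (-1) get zero weight and all weights are normalized by the number of positives:

import torch

seg_targets = torch.tensor([1, 0, -1, 1])  # fg, bg, ignored, fg
positives = seg_targets > 0
negative_cls_weights = (seg_targets == 0).float()
seg_weights = (negative_cls_weights + 1.0 * positives).float()
pos_normalizer = positives.sum(dim=0).float()
seg_weights /= torch.clamp(pos_normalizer, min=1.0)
print(seg_weights)  # tensor([0.5000, 0.5000, 0.0000, 0.5000])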
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch
from torch.nn import functional as F
from mmdet3d.models.roi_heads.base_3droi_head import Base3DRoIHead
from mmdet3d.registry import MODELS
from mmdet3d.structures import bbox3d2roi
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import InstanceList
from mmdet.models.task_modules import AssignResult
from mmdet.models.task_modules.samplers import SamplingResult
@MODELS.register_module()
class PVRCNNRoiHead(Base3DRoIHead):
"""RoI head for PV-RCNN.
Args:
num_classes (int): The number of classes. Defaults to 3.
semantic_head (dict, optional): Config of semantic head.
Defaults to None.
bbox_roi_extractor (dict, optional): Config of roi_extractor.
Defaults to None.
bbox_head (dict, optional): Config of bbox_head. Defaults to None.
train_cfg (dict, optional): Train config of model.
Defaults to None.
test_cfg (dict, optional): Test config of model.
Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
num_classes: int = 3,
semantic_head: Optional[dict] = None,
bbox_roi_extractor: Optional[dict] = None,
bbox_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None):
super(PVRCNNRoiHead, self).__init__(
bbox_head=bbox_head,
bbox_roi_extractor=bbox_roi_extractor,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
self.num_classes = num_classes
self.semantic_head = MODELS.build(semantic_head)
self.init_assigner_sampler()
@property
def with_semantic(self):
"""bool: whether the head has semantic branch"""
return hasattr(self,
'semantic_head') and self.semantic_head is not None
def loss(self, feats_dict: dict, rpn_results_list: InstanceList,
batch_data_samples: SampleList, **kwargs) -> dict:
"""Training forward function of PVRCNNROIHead.
Args:
feats_dict (dict): Contains point-wise features.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
dict: losses from each head.
- loss_semantic (torch.Tensor): loss of semantic head.
- loss_bbox (torch.Tensor): loss of bboxes.
- loss_cls (torch.Tensor): loss of object classification.
- loss_corner (torch.Tensor): loss of bboxes corners.
"""
losses = dict()
batch_gt_instances_3d = []
batch_gt_instances_ignore = []
for data_sample in batch_data_samples:
batch_gt_instances_3d.append(data_sample.gt_instances_3d)
if 'ignored_instances' in data_sample:
batch_gt_instances_ignore.append(data_sample.ignored_instances)
else:
batch_gt_instances_ignore.append(None)
if self.with_semantic:
semantic_results = self._semantic_forward_train(
feats_dict['keypoint_features'], feats_dict['keypoints'],
batch_gt_instances_3d)
losses['loss_semantic'] = semantic_results['loss_semantic']
sample_results = self._assign_and_sample(rpn_results_list,
batch_gt_instances_3d)
if self.with_bbox:
bbox_results = self._bbox_forward_train(
semantic_results['seg_preds'],
feats_dict['fusion_keypoint_features'],
feats_dict['keypoints'], sample_results)
losses.update(bbox_results['loss_bbox'])
return losses
def predict(self, feats_dict: dict, rpn_results_list: InstanceList,
batch_data_samples: SampleList, **kwargs) -> SampleList:
"""Perform forward propagation of the roi head and predict detection
results on the features of the upstream network.
Args:
feats_dict (dict): Contains point-wise features.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
list[:obj:`InstanceData`]: Detection results of each sample
after post-processing.
Each item usually contains the following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
contains a tensor with shape (num_instances, C), where
C >= 7.
"""
assert self.with_bbox, 'Bbox head must be implemented.'
assert self.with_semantic, 'Semantic head must be implemented.'
batch_input_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
semantic_results = self.semantic_head(feats_dict['keypoint_features'])
point_features = feats_dict[
'fusion_keypoint_features'] * semantic_results[
'seg_preds'].sigmoid().max(
dim=-1, keepdim=True).values
rois = bbox3d2roi(
[res['bboxes_3d'].tensor for res in rpn_results_list])
labels_3d = [res['labels_3d'] for res in rpn_results_list]
bbox_results = self._bbox_forward(point_features,
feats_dict['keypoints'], rois)
results_list = self.bbox_head.get_results(rois,
bbox_results['bbox_scores'],
bbox_results['bbox_reg'],
labels_3d, batch_input_metas,
self.test_cfg)
return results_list
def _bbox_forward_train(self, seg_preds: torch.Tensor,
keypoint_features: torch.Tensor,
keypoints: torch.Tensor,
sampling_results: SamplingResult) -> dict:
"""Forward training function of roi_extractor and bbox_head.
Args:
seg_preds (torch.Tensor): Point-wise semantic features.
keypoint_features (torch.Tensor): key points features
from points encoder.
keypoints (torch.Tensor): Coordinate of key points.
sampling_results (:obj:`SamplingResult`): Sampled results used
for training.
Returns:
dict: Forward results including losses and predictions.
"""
rois = bbox3d2roi([res.bboxes for res in sampling_results])
keypoint_features = keypoint_features * seg_preds.sigmoid().max(
dim=-1, keepdim=True).values
bbox_results = self._bbox_forward(keypoint_features, keypoints, rois)
bbox_targets = self.bbox_head.get_targets(sampling_results,
self.train_cfg)
loss_bbox = self.bbox_head.loss(bbox_results['bbox_scores'],
bbox_results['bbox_reg'], rois,
*bbox_targets)
bbox_results.update(loss_bbox=loss_bbox)
return bbox_results
def _bbox_forward(self, keypoint_features: torch.Tensor,
keypoints: torch.Tensor, rois: torch.Tensor) -> dict:
"""Forward function of roi_extractor and bbox_head used in both
training and testing.
Args:
keypoint_features (torch.Tensor): key points features
from points encoder.
keypoints (torch.Tensor): Coordinate of key points.
rois (Tensor): Roi boxes.
Returns:
dict: Contains predictions of bbox_head and
features of roi_extractor.
"""
pooled_keypoint_features = self.bbox_roi_extractor(
keypoint_features, keypoints[..., 1:], keypoints[..., 0].int(),
rois)
bbox_score, bbox_reg = self.bbox_head(pooled_keypoint_features)
bbox_results = dict(bbox_scores=bbox_score, bbox_reg=bbox_reg)
return bbox_results
def _assign_and_sample(
self, proposal_list: InstanceList,
batch_gt_instances_3d: InstanceList) -> List[SamplingResult]:
"""Assign and sample proposals for training.
Args:
proposal_list (list[:obj:`InstancesData`]): Proposals produced by
rpn head.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
list[:obj:`SamplingResult`]: Sampled results of each training
sample.
"""
sampling_results = []
# bbox assign
for batch_idx in range(len(proposal_list)):
cur_proposal_list = proposal_list[batch_idx]
cur_boxes = cur_proposal_list['bboxes_3d']
cur_labels_3d = cur_proposal_list['labels_3d']
cur_gt_instances_3d = batch_gt_instances_3d[batch_idx]
cur_gt_instances_3d.bboxes_3d = cur_gt_instances_3d.\
bboxes_3d.tensor
cur_gt_bboxes = batch_gt_instances_3d[batch_idx].bboxes_3d.to(
cur_boxes.device)
cur_gt_labels = batch_gt_instances_3d[batch_idx].labels_3d
batch_num_gts = 0
# 0 is bg
batch_gt_indis = cur_gt_labels.new_full((len(cur_boxes), ), 0)
batch_max_overlaps = cur_boxes.tensor.new_zeros(len(cur_boxes))
# -1 is bg
batch_gt_labels = cur_gt_labels.new_full((len(cur_boxes), ), -1)
# each class may have its own assigner
if isinstance(self.bbox_assigner, list):
for i, assigner in enumerate(self.bbox_assigner):
gt_per_cls = (cur_gt_labels == i)
pred_per_cls = (cur_labels_3d == i)
cur_assign_res = assigner.assign(
cur_proposal_list[pred_per_cls],
cur_gt_instances_3d[gt_per_cls])
# gather assign_results in different class into one result
batch_num_gts += cur_assign_res.num_gts
# gt inds (1-based)
gt_inds_arange_pad = gt_per_cls.nonzero(
as_tuple=False).view(-1) + 1
# pad 0 for indice unassigned
gt_inds_arange_pad = F.pad(
gt_inds_arange_pad, (1, 0), mode='constant', value=0)
# pad -1 for indice ignore
gt_inds_arange_pad = F.pad(
gt_inds_arange_pad, (1, 0), mode='constant', value=-1)
# convert to 0~gt_num+2 for indices
gt_inds_arange_pad += 1
# now 0 is bg, >1 is fg in batch_gt_indis
batch_gt_indis[pred_per_cls] = gt_inds_arange_pad[
cur_assign_res.gt_inds + 1] - 1
batch_max_overlaps[
pred_per_cls] = cur_assign_res.max_overlaps
batch_gt_labels[pred_per_cls] = cur_assign_res.labels
assign_result = AssignResult(batch_num_gts, batch_gt_indis,
batch_max_overlaps,
batch_gt_labels)
else: # for single class
assign_result = self.bbox_assigner.assign(
cur_proposal_list, cur_gt_instances_3d)
# sample boxes
sampling_result = self.bbox_sampler.sample(assign_result,
cur_boxes.tensor,
cur_gt_bboxes,
cur_gt_labels)
sampling_results.append(sampling_result)
return sampling_results
def _semantic_forward_train(self, keypoint_features: torch.Tensor,
keypoints: torch.Tensor,
batch_gt_instances_3d: InstanceList) -> dict:
"""Train semantic head.
Args:
keypoint_features (torch.Tensor): key points features
from points encoder.
keypoints (torch.Tensor): Coordinate of key points.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
dict: Segmentation results including losses
"""
semantic_results = self.semantic_head(keypoint_features)
semantic_targets = self.semantic_head.get_targets(
keypoints, batch_gt_instances_3d)
loss_semantic = self.semantic_head.loss(semantic_results,
semantic_targets)
semantic_results.update(loss_semantic)
return semantic_results
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.roi_heads.roi_extractors import SingleRoIExtractor
from .batch_roigridpoint_extractor import Batch3DRoIGridExtractor
from .single_roiaware_extractor import Single3DRoIAwareExtractor
from .single_roipoint_extractor import Single3DRoIPointExtractor
__all__ = [
'SingleRoIExtractor', 'Single3DRoIAwareExtractor',
'Single3DRoIPointExtractor', 'Batch3DRoIGridExtractor'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule
from mmdet3d.registry import MODELS
from mmdet3d.structures.bbox_3d import rotation_3d_in_axis
@MODELS.register_module()
class Batch3DRoIGridExtractor(BaseModule):
"""Grid point wise roi-aware Extractor.
Args:
grid_size (int): The number of grid points in a roi bbox.
Defaults to 6.
roi_layer (dict, optional): Config of sa module to get
grid points features. Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
grid_size: int = 6,
roi_layer: dict = None,
init_cfg: dict = None) -> None:
super(Batch3DRoIGridExtractor, self).__init__(init_cfg=init_cfg)
self.roi_grid_pool_layer = MODELS.build(roi_layer)
self.grid_size = grid_size
def forward(self, feats: torch.Tensor, coordinate: torch.Tensor,
batch_inds: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
"""Forward roi extractor to extract grid points feature.
Args:
feats (torch.Tensor): Key points features.
coordinate (torch.Tensor): Key points coordinates.
batch_inds (torch.Tensor): Input batch indexes.
rois (torch.Tensor): Detection results of rpn head.
Returns:
torch.Tensor: Grid points features.
"""
batch_size = int(batch_inds.max()) + 1
xyz = coordinate
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for k in range(batch_size):
xyz_batch_cnt[k] = (batch_inds == k).sum()
rois_batch_inds = rois[:, 0].int()
# (N1+N2+..., 6x6x6, 3)
roi_grid = self.get_dense_grid_points(rois[:, 1:])
new_xyz = roi_grid.view(-1, 3)
new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int()
for k in range(batch_size):
new_xyz_batch_cnt[k] = ((rois_batch_inds == k).sum() *
roi_grid.size(1))
pooled_points, pooled_features = self.roi_grid_pool_layer(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz.contiguous(),
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=feats.contiguous()) # (M1 + M2 ..., C)
pooled_features = pooled_features.view(-1, self.grid_size,
self.grid_size, self.grid_size,
pooled_features.shape[-1])
# (BxN, 6, 6, 6, C)
return pooled_features
def get_dense_grid_points(self, rois: torch.Tensor) -> torch.Tensor:
"""Get dense grid points from rois.
Args:
rois (torch.Tensor): Detection results of rpn head.
Returns:
torch.Tensor: Grid points coordinates.
"""
rois_bbox = rois.clone()
rois_bbox[:, 2] += rois_bbox[:, 5] / 2
faked_features = rois_bbox.new_ones(
(self.grid_size, self.grid_size, self.grid_size))
dense_idx = faked_features.nonzero()
dense_idx = dense_idx.repeat(rois_bbox.size(0), 1, 1).float()
dense_idx = ((dense_idx + 0.5) / self.grid_size)
dense_idx[..., :3] -= 0.5
roi_ctr = rois_bbox[:, :3]
roi_dim = rois_bbox[:, 3:6]
roi_grid_points = dense_idx * roi_dim.view(-1, 1, 3)
roi_grid_points = rotation_3d_in_axis(
roi_grid_points, rois_bbox[:, 6], axis=2)
roi_grid_points += roi_ctr.view(-1, 1, 3)
return roi_grid_points
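A small CPU-only sketch of get_dense_grid_points (it assumes an mmdet3d install containing this commit; the roi_layer config is only needed to construct the extractor and is not exercised here):

import torch

from mmdet3d.registry import MODELS
from mmdet3d.utils import register_all_modules

register_all_modules()
extractor = MODELS.build(
    dict(
        type='Batch3DRoIGridExtractor',
        grid_size=2,
        roi_layer=dict(
            type='StackedSAModuleMSG',
            in_channels=128,
            radius=(0.8, 1.6),
            sample_nums=(16, 16),
            mlp_channels=((64, 64), (64, 64)),
            use_xyz=True,
            pool_mod='max')))
# One axis-aligned box (x, y, z_bottom, dx, dy, dz, yaw) of size 2 x 2 x 2
# whose bottom centre sits at the origin.
rois = torch.tensor([[0., 0., 0., 2., 2., 2., 0.]])
grid = extractor.get_dense_grid_points(rois)  # shape (1, 2 * 2 * 2, 3)
# Each of the 8 grid points is offset by +-0.5 from the box centre (0, 0, 1).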
import unittest
import torch
from mmengine import DefaultScope
from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
_get_detector_cfg, _setup_seed)
class TestPVRCNN(unittest.TestCase):
def test_pvrcnn(self):
import mmdet3d.models
assert hasattr(mmdet3d.models, 'PointVoxelRCNN')
DefaultScope.get_instance('test_pvrcnn', scope_name='mmdet3d')
_setup_seed(0)
pvrcnn_cfg = _get_detector_cfg(
'pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py')
model = MODELS.build(pvrcnn_cfg)
num_gt_instance = 2
packed_inputs = _create_detector_inputs(
num_gt_instance=num_gt_instance)
# TODO: Support aug data test
# aug_packed_inputs = [
# _create_detector_inputs(num_gt_instance=num_gt_instance),
# _create_detector_inputs(num_gt_instance=num_gt_instance + 1)
# ]
# test_aug_test
# metainfo = {
# 'pcd_scale_factor': 1,
# 'pcd_horizontal_flip': 1,
# 'pcd_vertical_flip': 1,
# 'box_type_3d': LiDARInstance3DBoxes
# }
# for item in aug_packed_inputs:
# for batch_id in len(item['data_samples']):
# item['data_samples'][batch_id].set_metainfo(metainfo)
if torch.cuda.is_available():
model = model.cuda()
# test simple_test
with torch.no_grad():
data = model.data_preprocessor(packed_inputs, True)
torch.cuda.empty_cache()
results = model.forward(**data, mode='predict')
self.assertEqual(len(results), 1)
self.assertIn('bboxes_3d', results[0].pred_instances_3d)
self.assertIn('scores_3d', results[0].pred_instances_3d)
self.assertIn('labels_3d', results[0].pred_instances_3d)
# save the memory
with torch.no_grad():
losses = model.forward(**data, mode='loss')
torch.cuda.empty_cache()
self.assertGreater(losses['loss_rpn_cls'][0], 0)
self.assertGreaterEqual(losses['loss_rpn_bbox'][0], 0)
self.assertGreaterEqual(losses['loss_rpn_dir'][0], 0)
self.assertGreater(losses['loss_semantic'], 0)
self.assertGreaterEqual(losses['loss_bbox'], 0)
self.assertGreaterEqual(losses['loss_cls'], 0)
self.assertGreaterEqual(losses['loss_corner'], 0)
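Assuming the new test lives under tests/test_models/test_detectors/test_pvrcnn.py (the exact path is not shown in this diff), it can be run standalone with 'pytest tests/test_models/test_detectors/test_pvrcnn.py'; note that the forward and loss checks above only execute when a CUDA device is available.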