Commit aa544d4a authored by VVsssssk's avatar VVsssssk Committed by ZwwWayne

[Features] Support PV_RCNN modules (#1957)

* add pvrcnn module code

* add voxelsa

* fix

* fix comments

* fix comments

* fix comments

* add stack sa

* fix

* fix comments

* fix comments

* fix

* add ut

* fix comments
parent 1fd71531
_base_ = [
'../_base_/datasets/kitti-3d-3class.py',
'../_base_/schedules/cyclic-40e.py', '../_base_/default_runtime.py'
]
voxel_size = [0.05, 0.05, 0.1]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
data_root = 'data/kitti/'
class_names = ['Pedestrian', 'Cyclist', 'Car']
metainfo = dict(CLASSES=class_names)
db_sampler = dict(
data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0,
prepare=dict(
filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
classes=class_names,
sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10),
points_loader=dict(
type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4))
train_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler, use_ground_plane=True),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05]),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
model = dict(
type='PointVoxelRCNN',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
voxel_layer=dict(
max_num_points=5, # max_points_per_voxel
point_cloud_range=point_cloud_range,
voxel_size=voxel_size,
max_voxels=(16000, 40000))),
voxel_encoder=dict(type='HardSimpleVFE'),
middle_encoder=dict(
type='SparseEncoder',
in_channels=4,
sparse_shape=[41, 1600, 1408],
order=('conv', 'norm', 'act'),
encoder_paddings=((0, 0, 0), ((1, 1, 1), 0, 0), ((1, 1, 1), 0, 0),
((0, 1, 1), 0, 0)),
return_middle_feats=True),
points_encoder=dict(
type='VoxelSetAbstraction',
num_keypoints=2048,
fused_out_channel=128,
voxel_size=voxel_size,
point_cloud_range=point_cloud_range,
voxel_sa_cfgs_list=[
dict(
type='StackedSAModuleMSG',
in_channels=16,
scale_factor=1,
radius=(0.4, 0.8),
sample_nums=(16, 16),
mlp_channels=((16, 16), (16, 16)),
use_xyz=True),
dict(
type='StackedSAModuleMSG',
in_channels=32,
scale_factor=2,
radius=(0.8, 1.2),
sample_nums=(16, 32),
mlp_channels=((32, 32), (32, 32)),
use_xyz=True),
dict(
type='StackedSAModuleMSG',
in_channels=64,
scale_factor=4,
radius=(1.2, 2.4),
sample_nums=(16, 32),
mlp_channels=((64, 64), (64, 64)),
use_xyz=True),
dict(
type='StackedSAModuleMSG',
in_channels=64,
scale_factor=8,
radius=(2.4, 4.8),
sample_nums=(16, 32),
mlp_channels=((64, 64), (64, 64)),
use_xyz=True)
],
rawpoints_sa_cfgs=dict(
type='StackedSAModuleMSG',
in_channels=1,
radius=(0.4, 0.8),
sample_nums=(16, 16),
mlp_channels=((16, 16), (16, 16)),
use_xyz=True),
bev_feat_channel=256,
bev_scale_factor=8),
backbone=dict(
type='SECOND',
in_channels=256,
layer_nums=[5, 5],
layer_strides=[1, 2],
out_channels=[128, 256]),
neck=dict(
type='SECONDFPN',
in_channels=[128, 256],
upsample_strides=[1, 2],
out_channels=[256, 256]),
rpn_head=dict(
type='PartA2RPNHead',
num_classes=3,
in_channels=512,
feat_channels=512,
use_direction_classifier=True,
dir_offset=0.78539,
anchor_generator=dict(
type='Anchor3DRangeGenerator',
ranges=[[0, -40.0, -0.6, 70.4, 40.0, -0.6],
[0, -40.0, -0.6, 70.4, 40.0, -0.6],
[0, -40.0, -1.78, 70.4, 40.0, -1.78]],
sizes=[[0.8, 0.6, 1.73], [1.76, 0.6, 1.73], [3.9, 1.6, 1.56]],
rotations=[0, 1.57],
reshape_out=False),
diff_rad_by_sin=True,
assigner_per_size=True,
assign_per_class=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict(
type='mmdet.FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(
type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict(
type='mmdet.CrossEntropyLoss', use_sigmoid=False,
loss_weight=0.2)),
roi_head=dict(
type='PVRCNNRoiHead',
num_classes=3,
semantic_head=dict(
type='ForegroundSegmentationHead',
in_channels=640,
extra_width=0.1,
loss_seg=dict(
type='mmdet.FocalLoss',
use_sigmoid=True,
reduction='sum',
gamma=2.0,
alpha=0.25,
activated=True,
loss_weight=1.0)),
bbox_roi_extractor=dict(
type='Batch3DRoIGridExtractor',
grid_size=6,
roi_layer=dict(
type='StackedSAModuleMSG',
in_channels=128,
radius=(0.8, 1.6),
sample_nums=(16, 16),
mlp_channels=((64, 64), (64, 64)),
use_xyz=True,
pool_mod='max'),
),
bbox_head=dict(
type='PVRCNNBBoxHead',
in_channels=128,
grid_size=6,
num_classes=3,
class_agnostic=True,
shared_fc_channels=(256, 256),
reg_channels=(256, 256),
cls_channels=(256, 256),
dropout_ratio=0.3,
with_corner_loss=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_bbox=dict(
type='mmdet.SmoothL1Loss',
beta=1.0 / 9.0,
reduction='sum',
loss_weight=1.0),
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=[
dict( # for Pedestrian
type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.5,
neg_iou_thr=0.35,
min_pos_iou=0.35,
ignore_iof_thr=-1),
dict( # for Cyclist
type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.5,
neg_iou_thr=0.35,
min_pos_iou=0.35,
ignore_iof_thr=-1),
dict( # for Car
type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.6,
neg_iou_thr=0.45,
min_pos_iou=0.45,
ignore_iof_thr=-1)
],
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=9000,
nms_post=512,
max_num=512,
nms_thr=0.8,
score_thr=0,
use_rotate_nms=True),
rcnn=dict(
assigner=[
dict( # for Pedestrian
type='Max3DIoUAssigner',
iou_calculator=dict(
type='BboxOverlaps3D', coordinate='lidar'),
pos_iou_thr=0.55,
neg_iou_thr=0.55,
min_pos_iou=0.55,
ignore_iof_thr=-1),
dict( # for Cyclist
type='Max3DIoUAssigner',
iou_calculator=dict(
type='BboxOverlaps3D', coordinate='lidar'),
pos_iou_thr=0.55,
neg_iou_thr=0.55,
min_pos_iou=0.55,
ignore_iof_thr=-1),
dict( # for Car
type='Max3DIoUAssigner',
iou_calculator=dict(
type='BboxOverlaps3D', coordinate='lidar'),
pos_iou_thr=0.55,
neg_iou_thr=0.55,
min_pos_iou=0.55,
ignore_iof_thr=-1)
],
sampler=dict(
type='IoUNegPiecewiseSampler',
num=128,
pos_fraction=0.5,
neg_piece_fractions=[0.8, 0.2],
neg_iou_piece_thrs=[0.55, 0.1],
neg_pos_ub=-1,
add_gt_as_proposals=False,
return_iou=True),
cls_pos_thr=0.75,
cls_neg_thr=0.25)),
test_cfg=dict(
rpn=dict(
nms_pre=1024,
nms_post=100,
max_num=100,
nms_thr=0.7,
score_thr=0,
use_rotate_nms=True),
rcnn=dict(
use_rotate_nms=True,
use_raw_score=True,
nms_thr=0.1,
score_thr=0.1)))
train_dataloader = dict(
batch_size=2,
num_workers=2,
dataset=dict(dataset=dict(pipeline=train_pipeline, metainfo=metainfo)))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
eval_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
lr = 0.001
optim_wrapper = dict(optimizer=dict(lr=lr))
param_scheduler = [
# learning rate scheduler
# During the first 15 epochs, learning rate increases from lr to lr * 10
# during the next 25 epochs, learning rate decreases from lr * 10 to
# lr * 1e-4
dict(
type='CosineAnnealingLR',
T_max=15,
eta_min=lr * 10,
begin=0,
end=15,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=25,
eta_min=lr * 1e-4,
begin=15,
end=40,
by_epoch=True,
convert_to_iter_based=True),
# momentum scheduler
# During the first 15 epochs, momentum is annealed towards 0.85 / 0.95
# during the next 25 epochs, momentum is annealed from 0.85 / 0.95 towards 1
dict(
type='CosineAnnealingMomentum',
T_max=15,
eta_min=0.85 / 0.95,
begin=0,
end=15,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingMomentum',
T_max=25,
eta_min=1,
begin=15,
end=40,
by_epoch=True,
convert_to_iter_based=True)
]
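For reference, a minimal usage sketch (not part of this commit): it loads the config above through the `init_model` API touched in the next hunk and builds the detector without weights. The config path mirrors the name used by the unit test at the end of this commit and is an assumption, as is the availability of a CUDA device.
# Minimal sketch: build PV-RCNN from the config above.
# Assumes mmdet3d with its CUDA ops is installed and a GPU is available.
from mmdet3d.apis import init_model

config_file = 'configs/pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py'
model = init_model(config_file, checkpoint=None, device='cuda:0')
print(type(model).__name__)  # expected: PointVoxelRCNN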
...@@ -67,7 +67,6 @@ def init_model(config: Union[str, Path, Config],
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
# save the dataset_meta in the model for convenience
if 'dataset_meta' in checkpoint.get('meta', {}):
...
...@@ -14,6 +14,7 @@ from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN
from .mvx_two_stage import MVXTwoStageDetector
from .parta2 import PartA2
from .point_rcnn import PointRCNN
from .pv_rcnn import PointVoxelRCNN
from .sassd import SASSD
from .single_stage_mono3d import SingleStageMono3DDetector
from .smoke_mono3d import SMOKEMono3D
...@@ -26,5 +27,6 @@ __all__ = [
'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet',
'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector',
'FCOSMono3D', 'ImVoxelNet', 'GroupFree3DNet', 'PointRCNN', 'SMOKEMono3D',
'SASSD', 'MinkSingleStage3DDetector', 'MultiViewDfM', 'DfM',
'PointVoxelRCNN'
]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import InstanceList
from .two_stage import TwoStage3DDetector
@MODELS.register_module()
class PointVoxelRCNN(TwoStage3DDetector):
r"""PointVoxelRCNN detector.
Please refer to the `PV-RCNN paper <https://arxiv.org/abs/1912.13192>`_ for details.
Args:
voxel_encoder (dict): Point voxelization encoder layer.
middle_encoder (dict): Middle encoder layer
of the point cloud modality.
backbone (dict): Backbone for extracting point features.
neck (dict, optional): Neck for extracting point features.
Defaults to None.
rpn_head (dict, optional): Config of RPN head. Defaults to None.
points_encoder (dict, optional): Points encoder to extract point-wise
features. Defaults to None.
roi_head (dict, optional): Config of ROI head. Defaults to None.
train_cfg (dict, optional): Train config of model.
Defaults to None.
test_cfg (dict, optional): Test config of model.
Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`Det3DDataPreprocessor`. Defaults to None.
"""
def __init__(self,
voxel_encoder: dict,
middle_encoder: dict,
backbone: dict,
neck: Optional[dict] = None,
rpn_head: Optional[dict] = None,
points_encoder: Optional[dict] = None,
roi_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
data_preprocessor: Optional[dict] = None) -> None:
super().__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg,
data_preprocessor=data_preprocessor)
self.voxel_encoder = MODELS.build(voxel_encoder)
self.middle_encoder = MODELS.build(middle_encoder)
self.points_encoder = MODELS.build(points_encoder)
def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
input samples. Each Det3DDataSample usually contains
'pred_instances_3d', and ``pred_instances_3d`` usually
contains the following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instance, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (Tensor): Contains a tensor with shape
(num_instances, C) where C >=7.
"""
feats_dict = self.extract_feat(batch_inputs_dict)
if self.with_rpn:
rpn_results_list = self.rpn_head.predict(feats_dict,
batch_data_samples)
else:
rpn_results_list = [
data_sample.proposals for data_sample in batch_data_samples
]
# extract point-wise features with points_encoder
points_feats_dict = self.extract_points_feat(batch_inputs_dict,
feats_dict,
rpn_results_list)
results_list_3d = self.roi_head.predict(points_feats_dict,
rpn_results_list,
batch_data_samples)
# convert predictions to Det3DDataSample
results_list = self.add_pred_to_datasample(batch_data_samples,
results_list_3d)
return results_list
def extract_feat(self, batch_inputs_dict: dict) -> dict:
"""Extract features from the input voxels.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
Returns:
dict: We typically obtain a dict of features from the backbone +
neck, it includes:
- spatial_feats (torch.Tensor): Spatial feats from middle
encoder.
- multi_scale_3d_feats (list[torch.Tensor]): Multi scale
middle feats from middle encoder.
- neck_feats (torch.Tensor): Neck feats from neck.
"""
feats_dict = dict()
voxel_dict = batch_inputs_dict['voxels']
voxel_features = self.voxel_encoder(voxel_dict['voxels'],
voxel_dict['num_points'],
voxel_dict['coors'])
batch_size = voxel_dict['coors'][-1, 0].item() + 1
feats_dict['spatial_feats'], feats_dict[
'multi_scale_3d_feats'] = self.middle_encoder(
voxel_features, voxel_dict['coors'], batch_size)
x = self.backbone(feats_dict['spatial_feats'])
if self.with_neck:
neck_feats = self.neck(x)
feats_dict['neck_feats'] = neck_feats
return feats_dict
def extract_points_feat(self, batch_inputs_dict: dict, feats_dict: dict,
rpn_results_list: InstanceList) -> dict:
"""Extract point-wise features from the raw points and voxel features.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
feats_dict (dict): Contains features from the first stage.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
Returns:
dict: Contains point-wise features, including:
- keypoints (torch.Tensor): Sampled key points.
- keypoint_features (torch.Tensor): Key point features gathered
from multiple inputs.
- fusion_keypoint_features (torch.Tensor): Key point features
fused by point_feature_fusion_layer.
"""
return self.points_encoder(batch_inputs_dict, feats_dict,
rpn_results_list)
def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs):
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
dict: A dictionary of loss components.
"""
feats_dict = self.extract_feat(batch_inputs_dict)
losses = dict()
# RPN forward and loss
if self.with_rpn:
proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn)
rpn_data_samples = copy.deepcopy(batch_data_samples)
rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(
feats_dict,
rpn_data_samples,
proposal_cfg=proposal_cfg,
**kwargs)
# avoid get same name with roi_head loss
keys = rpn_losses.keys()
for key in keys:
if 'loss' in key and 'rpn' not in key:
rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
losses.update(rpn_losses)
else:
# TODO: Not support currently, should have a check at Fast R-CNN
assert batch_data_samples[0].get('proposals', None) is not None
# use pre-defined proposals in InstanceData for the second stage
# to extract ROI features.
rpn_results_list = [
data_sample.proposals for data_sample in batch_data_samples
]
points_feats_dict = self.extract_points_feat(batch_inputs_dict,
feats_dict,
rpn_results_list)
roi_losses = self.roi_head.loss(points_feats_dict, rpn_results_list,
batch_data_samples)
losses.update(roi_losses)
return losses
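As an aside, a hedged sketch of the input structure `extract_feat`, `loss` and `predict` consume; the tensor sizes are dummy values, and in practice the dict is produced by `Det3DDataPreprocessor` with `voxel=True` as configured above.
# Illustration only: the batch_inputs_dict layout consumed by this detector.
import torch

batch_inputs_dict = dict(
    points=[torch.rand(10000, 4), torch.rand(12000, 4)],  # per-sample (x, y, z, intensity)
    voxels=dict(
        voxels=torch.rand(5000, 5, 4),  # (num_voxels, max_points_per_voxel, point_dim)
        num_points=torch.randint(1, 6, (5000,)),  # valid points per voxel
        coors=torch.zeros(5000, 4, dtype=torch.int32)))  # (batch_idx, z, y, x)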
...@@ -4,9 +4,10 @@ from .paconv_sa_module import (PAConvCUDASAModule, PAConvCUDASAModuleMSG,
PAConvSAModule, PAConvSAModuleMSG)
from .point_fp_module import PointFPModule
from .point_sa_module import PointSAModule, PointSAModuleMSG
from .stack_point_sa_module import StackedSAModuleMSG
__all__ = [
'build_sa_module', 'PointSAModuleMSG', 'PointSAModule', 'PointFPModule',
'PAConvSAModule', 'PAConvSAModuleMSG', 'PAConvCUDASAModule',
'PAConvCUDASAModuleMSG', 'StackedSAModuleMSG'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import ball_query, grouping_operation
from mmengine.model import BaseModule
from torch import Tensor
from mmdet3d.registry import MODELS
class StackQueryAndGroup(BaseModule):
"""Find nearby points in spherical space.
Args:
radius (float): Radius of the ball query.
sample_nums (int): Number of samples in each ball query.
use_xyz (bool): Whether to use xyz. Default: True.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
radius: float,
sample_nums: int,
use_xyz: bool = True,
init_cfg: dict = None):
super().__init__(init_cfg=init_cfg)
self.radius, self.sample_nums, self.use_xyz = \
radius, sample_nums, use_xyz
def forward(self,
xyz: torch.Tensor,
xyz_batch_cnt: torch.Tensor,
new_xyz: torch.Tensor,
new_xyz_batch_cnt: torch.Tensor,
features: torch.Tensor = None) -> Tuple[Tensor, Tensor]:
"""Forward.
Args:
xyz (Tensor): Tensor of the xyz coordinates
of the features shape with (N1 + N2 ..., 3).
xyz_batch_cnt: (Tensor): Stacked input xyz coordinates nums in
each batch, just like (N1, N2, ...).
new_xyz (Tensor): New coords of the outputs shape with
(M1 + M2 ..., 3).
new_xyz_batch_cnt: (Tensor): Stacked new xyz coordinates nums
in each batch, just like (M1, M2, ...).
features (Tensor, optional): Features of each point with shape
(N1 + N2 ..., C). C is features channel number. Default: None.
"""
assert xyz.shape[0] == xyz_batch_cnt.sum(), \
f'xyz: {xyz.shape}, xyz_batch_cnt: {xyz_batch_cnt}'
assert new_xyz.shape[0] == new_xyz_batch_cnt.sum(), \
f'new_xyz: {new_xyz.shape}, ' \
f'new_xyz_batch_cnt: {new_xyz_batch_cnt}'
# idx: (M1 + M2 ..., nsample), empty_ball_mask: (M1 + M2 ...)
idx, empty_ball_mask = ball_query(0, self.radius, self.sample_nums,
xyz, new_xyz, xyz_batch_cnt,
new_xyz_batch_cnt)
grouped_xyz = grouping_operation(
xyz, idx, xyz_batch_cnt,
new_xyz_batch_cnt) # (M1 + M2, 3, nsample)
grouped_xyz -= new_xyz.unsqueeze(-1)
grouped_xyz[empty_ball_mask] = 0
if features is not None:
grouped_features = grouping_operation(
features, idx, xyz_batch_cnt,
new_xyz_batch_cnt) # (M1 + M2, C, nsample)
grouped_features[empty_ball_mask] = 0
if self.use_xyz:
new_features = torch.cat(
[grouped_xyz, grouped_features],
dim=1) # (M1 + M2 ..., C + 3, nsample)
else:
new_features = grouped_features
else:
assert self.use_xyz, \
'use_xyz must be True when features is None!'
new_features = grouped_xyz
return new_features, idx
@MODELS.register_module()
class StackedSAModuleMSG(BaseModule):
"""Stack point set abstraction module.
Args:
in_channels (int): Input channels.
radius (list[float]): List of radii for the ball queries.
sample_nums (list[int]): Number of samples in each ball query.
mlp_channels (list[list[int]]): Specify mlp channels of the
pointnet before the global pooling for each scale to encode
point features.
use_xyz (bool): Whether to use xyz. Default: True.
pool_mod (str): Type of pooling method.
Default: 'max'.
norm_cfg (dict): Type of normalization method. Defaults to
dict(type='BN2d', eps=1e-5, momentum=0.01).
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
in_channels: int,
radius: List[float],
sample_nums: List[int],
mlp_channels: List[List[int]],
use_xyz: bool = True,
pool_mod='max',
norm_cfg: dict = dict(type='BN2d', eps=1e-5, momentum=0.01),
init_cfg: dict = None,
**kwargs) -> None:
super(StackedSAModuleMSG, self).__init__(init_cfg=init_cfg)
assert len(radius) == len(sample_nums) == len(mlp_channels)
self.groupers = nn.ModuleList()
self.mlps = nn.ModuleList()
for i in range(len(radius)):
cin = in_channels
if use_xyz:
cin += 3
cur_radius = radius[i]
nsample = sample_nums[i]
mlp_spec = mlp_channels[i]
self.groupers.append(
StackQueryAndGroup(cur_radius, nsample, use_xyz=use_xyz))
mlp = nn.Sequential()
for j in range(len(mlp_spec)):
cout = mlp_spec[j]
mlp.add_module(
f'layer{j}',
ConvModule(
cin,
cout,
kernel_size=(1, 1),
stride=(1, 1),
conv_cfg=dict(type='Conv2d'),
norm_cfg=norm_cfg,
bias=False))
cin = cout
self.mlps.append(mlp)
self.pool_mod = pool_mod
def forward(self,
xyz: Tensor,
xyz_batch_cnt: Tensor,
new_xyz: Tensor,
new_xyz_batch_cnt: Tensor,
features: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Forward.
Args:
xyz (Tensor): Tensor of the xyz coordinates
of the features shape with (N1 + N2 ..., 3).
xyz_batch_cnt: (Tensor): Stacked input xyz coordinates nums in
each batch, just like (N1, N2, ...).
new_xyz (Tensor): New coords of the outputs shape with
(M1 + M2 ..., 3).
new_xyz_batch_cnt: (Tensor): Stacked new xyz coordinates nums
in each batch, just like (M1, M2, ...).
features (Tensor, optional): Features of each point with shape
(N1 + N2 ..., C). C is features channel number. Default: None.
Returns:
Return new point coordinates and features:
- new_xyz (Tensor): Target point coordinates with shape
(M1 + M2 ..., 3).
- new_features (Tensor): Target points features with shape
(M1 + M2 ..., sum_k(mlps[k][-1])).
"""
new_features_list = []
for k in range(len(self.groupers)):
grouped_features, ball_idxs = self.groupers[k](
xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt,
features) # (M1 + M2, Cin, nsample)
grouped_features = grouped_features.permute(1, 0,
2).unsqueeze(dim=0)
new_features = self.mlps[k](grouped_features)
# (1, Cout, M1 + M2 ..., nsample)
if self.pool_mod == 'max':
new_features = new_features.max(-1).values
elif self.pool_mod == 'avg':
new_features = new_features.mean(-1)
else:
raise NotImplementedError
new_features = new_features.squeeze(dim=0).permute(1, 0)
new_features_list.append(new_features)
new_features = torch.cat(new_features_list, dim=1)
return new_xyz, new_features
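A short usage sketch for StackedSAModuleMSG on batch-stacked points; it is only a sanity-check example, it assumes a CUDA device because `ball_query` and `grouping_operation` from mmcv are GPU ops, and all sizes are arbitrary.
# Usage sketch (requires CUDA; sizes are arbitrary).
import torch
import mmdet3d.models  # noqa: F401, registers StackedSAModuleMSG
from mmdet3d.registry import MODELS

sa_module = MODELS.build(
    dict(
        type='StackedSAModuleMSG',
        in_channels=16,
        radius=(0.4, 0.8),
        sample_nums=(16, 16),
        mlp_channels=((16, 16), (16, 16)),
        use_xyz=True)).cuda()

# Two samples stacked along dim 0: 100 + 80 source points, 50 + 40 key points.
xyz = torch.rand(180, 3).cuda()
xyz_batch_cnt = torch.tensor([100, 80], dtype=torch.int32).cuda()
new_xyz = torch.rand(90, 3).cuda()
new_xyz_batch_cnt = torch.tensor([50, 40], dtype=torch.int32).cuda()
features = torch.rand(180, 16).cuda()

out_xyz, out_feats = sa_module(xyz, xyz_batch_cnt, new_xyz,
                               new_xyz_batch_cnt, features)
# out_feats has shape (90, 32): one 16-dim vector per scale, concatenated.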
...@@ -2,7 +2,9 @@
from .pillar_scatter import PointPillarsScatter
from .sparse_encoder import SparseEncoder, SparseEncoderSASSD
from .sparse_unet import SparseUNet
from .voxel_set_abstraction import VoxelSetAbstraction
__all__ = [
'PointPillarsScatter', 'SparseEncoder', 'SparseEncoderSASSD', 'SparseUNet',
'VoxelSetAbstraction'
]
...@@ -41,6 +41,8 @@ class SparseEncoder(nn.Module):
Defaults to ((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)).
block_type (str, optional): Type of the block to use.
Defaults to 'conv_module'.
return_middle_feats (bool): Whether to output middle features.
Defaults to False.
"""
def __init__(self,
...@@ -54,7 +56,8 @@ class SparseEncoder(nn.Module):
64)),
encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1)),
block_type='conv_module',
return_middle_feats=False):
super().__init__()
assert block_type in ['conv_module', 'basicblock']
self.sparse_shape = sparse_shape
...@@ -66,6 +69,7 @@ class SparseEncoder(nn.Module):
self.encoder_paddings = encoder_paddings
self.stage_num = len(self.encoder_channels)
self.fp16_enabled = False
self.return_middle_feats = return_middle_feats
# Spconv init all weight on its own
assert isinstance(order, tuple) and len(order) == 3
...@@ -117,7 +121,14 @@ class SparseEncoder(nn.Module):
batch_size (int): Batch size.
Returns:
torch.Tensor | tuple[torch.Tensor, list]: Spatial features,
optionally together with the middle encoder features:
- spatial_features (torch.Tensor): Spatial features output by
the last layer.
- encode_features (List[SparseConvTensor], optional): Middle
layer output features, returned only when
self.return_middle_feats is True.
"""
coors = coors.int()
input_sp_tensor = SparseConvTensor(voxel_features, coors,
...@@ -137,7 +148,10 @@ class SparseEncoder(nn.Module):
N, C, D, H, W = spatial_features.shape
spatial_features = spatial_features.view(N, C * D, H, W)
if self.return_middle_feats:
return spatial_features, encode_features
else:
return spatial_features
def make_encoder_layers(self,
make_block,
...
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import mmengine
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops.furthest_point_sample import furthest_point_sample
from mmengine.model import BaseModule
from mmdet3d.registry import MODELS
from mmdet3d.utils import InstanceList
def bilinear_interpolate_torch(inputs, x, y):
"""Bilinear interpolate for inputs."""
x0 = torch.floor(x).long()
x1 = x0 + 1
y0 = torch.floor(y).long()
y1 = y0 + 1
x0 = torch.clamp(x0, 0, inputs.shape[1] - 1)
x1 = torch.clamp(x1, 0, inputs.shape[1] - 1)
y0 = torch.clamp(y0, 0, inputs.shape[0] - 1)
y1 = torch.clamp(y1, 0, inputs.shape[0] - 1)
Ia = inputs[y0, x0]
Ib = inputs[y1, x0]
Ic = inputs[y0, x1]
Id = inputs[y1, x1]
wa = (x1.type_as(x) - x) * (y1.type_as(y) - y)
wb = (x1.type_as(x) - x) * (y - y0.type_as(y))
wc = (x - x0.type_as(x)) * (y1.type_as(y) - y)
wd = (x - x0.type_as(x)) * (y - y0.type_as(y))
ans = torch.t((torch.t(Ia) * wa)) + torch.t(torch.t(Ib) * wb) + torch.t(
torch.t(Ic) * wc) + torch.t(torch.t(Id) * wd)
return ans
@MODELS.register_module()
class VoxelSetAbstraction(BaseModule):
"""Voxel set abstraction module for PVRCNN and PVRCNN++.
Args:
num_keypoints (int): The number of key points sampled from
the raw point cloud.
fused_out_channel (int): Number of output channels after the key
point features are fused. Defaults to 128.
voxel_size (list[float]): Size of voxels. Defaults to
[0.05, 0.05, 0.1].
point_cloud_range (list[float]): Point cloud range. Defaults to
[0, -40, -3, 70.4, 40, 1].
voxel_sa_cfgs_list (List[dict or ConfigDict], optional): List of SA
module configs used to gather key point features from
multi-scale voxel features. Defaults to None.
rawpoints_sa_cfgs (dict or ConfigDict, optional): SA module config
used to gather key point features from raw points. Defaults
to None.
bev_feat_channel (int): Number of BEV feature channels.
Defaults to 256.
bev_scale_factor (int): BEV feature scale factor. Defaults to 8.
voxel_center_as_source (bool): Whether to use voxel centers as
key points. Defaults to False.
norm_cfg (dict[str]): Config of normalization layer. Defaults to
dict(type='BN2d', eps=1e-5, momentum=0.1).
bias (bool | str, optional): If specified as `auto`, it will be
decided by `norm_cfg`. `bias` will be set as True if
`norm_cfg` is None, otherwise False. Defaults to 'auto'.
"""
def __init__(self,
num_keypoints: int,
fused_out_channel: int = 128,
voxel_size: list = [0.05, 0.05, 0.1],
point_cloud_range: list = [0, -40, -3, 70.4, 40, 1],
voxel_sa_cfgs_list: Optional[list] = None,
rawpoints_sa_cfgs: Optional[dict] = None,
bev_feat_channel: int = 256,
bev_scale_factor: int = 8,
voxel_center_as_source: bool = False,
norm_cfg: dict = dict(type='BN2d', eps=1e-5, momentum=0.1),
bias: str = 'auto') -> None:
super().__init__()
self.num_keypoints = num_keypoints
self.fused_out_channel = fused_out_channel
self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range
self.voxel_center_as_source = voxel_center_as_source
gathered_channel = 0
if rawpoints_sa_cfgs is not None:
self.rawpoints_sa_layer = MODELS.build(rawpoints_sa_cfgs)
gathered_channel += sum(
[x[-1] for x in rawpoints_sa_cfgs.mlp_channels])
else:
self.rawpoints_sa_layer = None
if voxel_sa_cfgs_list is not None:
self.voxel_sa_configs_list = voxel_sa_cfgs_list
self.voxel_sa_layers = nn.ModuleList()
for voxel_sa_config in voxel_sa_cfgs_list:
cur_layer = MODELS.build(voxel_sa_config)
self.voxel_sa_layers.append(cur_layer)
gathered_channel += sum(
[x[-1] for x in voxel_sa_config.mlp_channels])
else:
self.voxel_sa_layers = None
if bev_feat_channel is not None and bev_scale_factor is not None:
self.bev_cfg = mmengine.Config(
dict(
bev_feat_channels=bev_feat_channel,
bev_scale_factor=bev_scale_factor))
gathered_channel += bev_feat_channel
else:
self.bev_cfg = None
self.point_feature_fusion_layer = nn.Sequential(
ConvModule(
gathered_channel,
fused_out_channel,
kernel_size=(1, 1),
stride=(1, 1),
conv_cfg=dict(type='Conv2d'),
norm_cfg=norm_cfg,
bias=bias))
def interpolate_from_bev_features(self, keypoints: torch.Tensor,
bev_features: torch.Tensor,
batch_size: int,
bev_scale_factor: int) -> torch.Tensor:
"""Gather key points features from bev feature map by interpolate.
Args:
keypoints (torch.Tensor): Sampled key points with shape
(B, num_keypoints, 3).
bev_features (torch.Tensor): Bev feature map from the first
stage with shape (B, C, H, W).
batch_size (int): Input batch size.
bev_scale_factor (int): Bev feature map scale factor.
Returns:
torch.Tensor: Key point features gathered from the BEV feature
map, with shape (B, num_keypoints, C).
"""
x_idxs = (keypoints[..., 0] -
self.point_cloud_range[0]) / self.voxel_size[0]
y_idxs = (keypoints[..., 1] -
self.point_cloud_range[1]) / self.voxel_size[1]
x_idxs = x_idxs / bev_scale_factor
y_idxs = y_idxs / bev_scale_factor
point_bev_features_list = []
for k in range(batch_size):
cur_x_idxs = x_idxs[k, ...]
cur_y_idxs = y_idxs[k, ...]
cur_bev_features = bev_features[k].permute(1, 2, 0) # (H, W, C)
point_bev_features = bilinear_interpolate_torch(
cur_bev_features, cur_x_idxs, cur_y_idxs)
point_bev_features_list.append(point_bev_features)
point_bev_features = torch.cat(
point_bev_features_list, dim=0) # (N1 + N2 + ..., C)
return point_bev_features.view(batch_size, keypoints.shape[1], -1)
def get_voxel_centers(self, coors: torch.Tensor,
scale_factor: float) -> torch.Tensor:
"""Get voxel centers coordinate.
Args:
coors (torch.Tensor): Coordinates of voxels shape is Nx(1+NDim),
where 1 represents the batch index.
scale_factor (float): Scale factor.
Returns:
torch.Tensor: Voxel centers coordinate with shape (N, 3).
"""
assert coors.shape[1] == 4
voxel_centers = coors[:, [3, 2, 1]].float() # (xyz)
voxel_size = torch.tensor(
self.voxel_size,
device=voxel_centers.device).float() * scale_factor
pc_range = torch.tensor(
self.point_cloud_range[0:3], device=voxel_centers.device).float()
voxel_centers = (voxel_centers + 0.5) * voxel_size + pc_range
return voxel_centers
def sample_key_points(self, points: List[torch.Tensor],
coors: torch.Tensor) -> torch.Tensor:
"""Sample key points from raw points cloud.
Args:
points (List[torch.Tensor]): Point cloud of each sample.
coors (torch.Tensor): Coordinates of voxels shape is Nx(1+NDim),
where 1 represents the batch index.
Returns:
torch.Tensor: (B, M, 3) Key points of each sample.
M is num_keypoints.
"""
assert points is not None or coors is not None
if self.voxel_center_as_source:
_src_points = self.get_voxel_centers(coors=coors, scale_factor=1)
batch_size = coors[-1, 0].item() + 1
src_points = [
_src_points[coors[:, 0] == b] for b in range(batch_size)
]
else:
src_points = [p[..., :3] for p in points]
keypoints_list = []
for points_to_sample in src_points:
num_points = points_to_sample.shape[0]
cur_pt_idxs = furthest_point_sample(
points_to_sample.unsqueeze(dim=0).contiguous(),
self.num_keypoints).long()[0]
if num_points < self.num_keypoints:
times = int(self.num_keypoints / num_points) + 1
non_empty = cur_pt_idxs[:num_points]
cur_pt_idxs = non_empty.repeat(times)[:self.num_keypoints]
keypoints = points_to_sample[cur_pt_idxs]
keypoints_list.append(keypoints)
keypoints = torch.stack(keypoints_list, dim=0) # (B, M, 3)
return keypoints
def forward(self, batch_inputs_dict: dict, feats_dict: dict,
rpn_results_list: InstanceList) -> dict:
"""Extract point-wise features from multi-input.
Args:
batch_inputs_dict (dict): The model input dict which include
'points', 'voxels' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- voxels (dict[torch.Tensor]): Voxels of the batch sample.
feats_dict (dict): Contains features from the first
stage.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
Returns:
dict: Contains point-wise features, including:
- keypoints (torch.Tensor): Sampled key points.
- keypoint_features (torch.Tensor): Key point features gathered
from multiple inputs.
- fusion_keypoint_features (torch.Tensor): Key point features
fused by point_feature_fusion_layer.
"""
points = batch_inputs_dict['points']
voxel_encode_features = feats_dict['multi_scale_3d_feats']
bev_encode_features = feats_dict['spatial_feats']
if self.voxel_center_as_source:
voxels_coors = batch_inputs_dict['voxels']['coors']
else:
voxels_coors = None
keypoints = self.sample_key_points(points, voxels_coors)
point_features_list = []
batch_size = len(points)
if self.bev_cfg is not None:
point_bev_features = self.interpolate_from_bev_features(
keypoints, bev_encode_features, batch_size,
self.bev_cfg.bev_scale_factor)
point_features_list.append(point_bev_features.contiguous())
batch_size, num_keypoints, _ = keypoints.shape
key_xyz = keypoints.view(-1, 3)
key_xyz_batch_cnt = key_xyz.new_zeros(batch_size).int().fill_(
num_keypoints)
if self.rawpoints_sa_layer is not None:
batch_points = torch.cat(points, dim=0)
batch_cnt = [len(p) for p in points]
xyz = batch_points[:, :3].contiguous()
features = None
if batch_points.size(1) > 3:
# channels beyond xyz (e.g. intensity) are treated as features
features = batch_points[:, 3:].contiguous()
xyz_batch_cnt = xyz.new_tensor(batch_cnt, dtype=torch.int32)
pooled_points, pooled_features = self.rawpoints_sa_layer(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=key_xyz.contiguous(),
new_xyz_batch_cnt=key_xyz_batch_cnt,
features=features,
)
point_features_list.append(pooled_features.contiguous().view(
batch_size, num_keypoints, -1))
if self.voxel_sa_layers is not None:
for k, voxel_sa_layer in enumerate(self.voxel_sa_layers):
cur_coords = voxel_encode_features[k].indices
xyz = self.get_voxel_centers(
coors=cur_coords,
scale_factor=self.voxel_sa_configs_list[k].scale_factor
).contiguous()
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (cur_coords[:, 0] == bs_idx).sum()
pooled_points, pooled_features = voxel_sa_layer(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=key_xyz.contiguous(),
new_xyz_batch_cnt=key_xyz_batch_cnt,
features=voxel_encode_features[k].features.contiguous(),
)
point_features_list.append(pooled_features.contiguous().view(
batch_size, num_keypoints, -1))
point_features = torch.cat(
point_features_list, dim=-1).view(batch_size * num_keypoints, -1,
1)
fusion_point_features = self.point_feature_fusion_layer(
point_features.unsqueeze(dim=-1)).squeeze(dim=-1)
batch_idxs = torch.arange(
batch_size * num_keypoints, device=keypoints.device
) // num_keypoints  # batch index of each key point
batch_keypoints_xyz = torch.cat(
(batch_idxs.to(key_xyz.dtype).unsqueeze(dim=-1), key_xyz), dim=-1)
return dict(
keypoint_features=point_features.squeeze(dim=-1),
fusion_keypoint_features=fusion_point_features.squeeze(dim=-1),
keypoints=batch_keypoints_xyz)
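To connect this module with the config at the top of this commit: `gathered_channel` is the sum of the BEV channels, the raw-point SA output channels, and each voxel SA layer's output channels, which is why `ForegroundSegmentationHead` is configured with `in_channels=640`. A quick check of that arithmetic (illustration only):
# Channel bookkeeping for the KITTI config in this commit.
bev = 256                        # bev_feat_channel
raw_sa = 16 + 16                 # last dims of rawpoints_sa_cfgs.mlp_channels
voxel_sa = (16 + 16) + (32 + 32) + (64 + 64) + (64 + 64)  # four voxel SA layers
gathered_channel = bev + raw_sa + voxel_sa
assert gathered_channel == 640   # matches semantic_head.in_channels
# point_feature_fusion_layer maps 640 -> fused_out_channel (128), the
# in_channels used by the RoI grid extractor and PVRCNNBBoxHead.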
...@@ -5,10 +5,11 @@ from .h3d_roi_head import H3DRoIHead
from .mask_heads import PointwiseSemanticHead, PrimitiveHead
from .part_aggregation_roi_head import PartAggregationROIHead
from .point_rcnn_roi_head import PointRCNNRoIHead
from .pv_rcnn_roi_head import PVRCNNRoiHead
from .roi_extractors import Single3DRoIAwareExtractor, SingleRoIExtractor
__all__ = [
'Base3DRoIHead', 'PartAggregationROIHead', 'PointwiseSemanticHead',
'Single3DRoIAwareExtractor', 'PartA2BboxHead', 'SingleRoIExtractor',
'H3DRoIHead', 'PrimitiveHead', 'PointRCNNRoIHead', 'PVRCNNRoiHead'
]
...@@ -7,9 +7,10 @@ from mmdet.models.roi_heads.bbox_heads import (BBoxHead, ConvFCBBoxHead,
from .h3d_bbox_head import H3DBboxHead
from .parta2_bbox_head import PartA2BboxHead
from .point_rcnn_bbox_head import PointRCNNBboxHead
from .pv_rcnn_bbox_head import PVRCNNBBoxHead
__all__ = [
'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead',
'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'PartA2BboxHead',
'H3DBboxHead', 'PointRCNNBboxHead', 'PVRCNNBBoxHead'
]
# Copyright (c) OpenMMLab. All rights reserved.
from .foreground_segmentation_head import ForegroundSegmentationHead
from .pointwise_semantic_head import PointwiseSemanticHead
from .primitive_head import PrimitiveHead
__all__ = [
'PointwiseSemanticHead', 'PrimitiveHead', 'ForegroundSegmentationHead'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple
import torch
from mmcv.cnn.bricks import build_norm_layer
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import nn as nn
from mmdet3d.models.builder import build_loss
from mmdet3d.registry import MODELS
from mmdet3d.utils import InstanceList
from mmdet.models.utils import multi_apply
@MODELS.register_module()
class ForegroundSegmentationHead(BaseModule):
"""Foreground segmentation head.
Args:
in_channels (int): The number of input channel.
mlp_channels (tuple[int]): Sizes of the MLP hidden channels.
Defaults to (256, 256).
extra_width (float): Width by which ground truth boxes are
enlarged. Defaults to 0.1.
norm_cfg (dict): Type of normalization method. Defaults to
dict(type='BN1d', eps=1e-5, momentum=0.1).
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
loss_seg (dict): Config of segmentation loss. Defaults to
dict(type='mmdet.FocalLoss')
"""
def __init__(
self,
in_channels: int,
mlp_channels: Tuple[int] = (256, 256),
extra_width: float = 0.1,
norm_cfg: dict = dict(type='BN1d', eps=1e-5, momentum=0.1),
init_cfg: Optional[dict] = None,
loss_seg: dict = dict(
type='mmdet.FocalLoss',
use_sigmoid=True,
reduction='sum',
gamma=2.0,
alpha=0.25,
activated=True,
loss_weight=1.0)
) -> None:
super(ForegroundSegmentationHead, self).__init__(init_cfg=init_cfg)
self.extra_width = extra_width
self.num_classes = 1
self.in_channels = in_channels
self.use_sigmoid_cls = loss_seg.get('use_sigmoid', False)
out_channels = 1
if self.use_sigmoid_cls:
self.out_channels = out_channels
else:
self.out_channels = out_channels + 1
mlps_layers = []
cin = in_channels
for mlp in mlp_channels:
mlps_layers.extend([
nn.Linear(cin, mlp, bias=False),
build_norm_layer(norm_cfg, mlp)[1],
nn.ReLU()
])
cin = mlp
mlps_layers.append(nn.Linear(cin, self.out_channels, bias=True))
self.seg_cls_layer = nn.Sequential(*mlps_layers)
self.loss_seg = build_loss(loss_seg)
def forward(self, feats: torch.Tensor) -> dict:
"""Forward head.
Args:
feats (torch.Tensor): Point-wise features.
Returns:
dict: Segment predictions.
"""
seg_preds = self.seg_cls_layer(feats)
return dict(seg_preds=seg_preds)
def _get_targets_single(self, point_xyz: torch.Tensor,
gt_bboxes_3d: InstanceData,
gt_labels_3d: torch.Tensor) -> torch.Tensor:
"""generate segmentation targets for a single sample.
Args:
point_xyz (torch.Tensor): Coordinate of points.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth boxes in
shape (box_num, 7).
gt_labels_3d (torch.Tensor): Class labels of ground truths in
shape (box_num).
Returns:
torch.Tensor: Points class labels.
"""
point_cls_labels_single = point_xyz.new_zeros(
point_xyz.shape[0]).long()
enlarged_gt_boxes = gt_bboxes_3d.enlarged_box(self.extra_width)
box_idxs_of_pts = gt_bboxes_3d.points_in_boxes_part(point_xyz).long()
extend_box_idxs_of_pts = enlarged_gt_boxes.points_in_boxes_part(
point_xyz).long()
box_fg_flag = box_idxs_of_pts >= 0
fg_flag = box_fg_flag.clone()
ignore_flag = fg_flag ^ (extend_box_idxs_of_pts >= 0)
point_cls_labels_single[ignore_flag] = -1
gt_box_of_fg_points = gt_labels_3d[box_idxs_of_pts[fg_flag]]
point_cls_labels_single[
fg_flag] = 1 if self.num_classes == 1 else\
gt_box_of_fg_points.long()
return point_cls_labels_single,
def get_targets(self, points_bxyz: torch.Tensor,
batch_gt_instances_3d: InstanceList) -> dict:
"""Generate segmentation targets.
Args:
points_bxyz (torch.Tensor): Stacked key point coordinates with
batch indices, in shape (num_points, 1 + 3).
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
dict: Prediction targets
- seg_targets (torch.Tensor): Segmentation targets.
"""
batch_size = len(batch_gt_instances_3d)
points_xyz_list = []
gt_bboxes_3d = []
gt_labels_3d = []
for idx in range(batch_size):
coords_idx = points_bxyz[:, 0] == idx
points_xyz_list.append(points_bxyz[coords_idx][..., 1:])
gt_bboxes_3d.append(batch_gt_instances_3d[idx].bboxes_3d)
gt_labels_3d.append(batch_gt_instances_3d[idx].labels_3d)
seg_targets, = multi_apply(self._get_targets_single, points_xyz_list,
gt_bboxes_3d, gt_labels_3d)
seg_targets = torch.cat(seg_targets, dim=0)
return dict(seg_targets=seg_targets)
def loss(self, semantic_results: dict,
semantic_targets: dict) -> Dict[str, torch.Tensor]:
"""Calculate point-wise segmentation losses.
Args:
semantic_results (dict): Results from semantic head.
semantic_targets (dict): Targets of semantic results.
Returns:
dict: Loss of segmentation.
- loss_semantic (torch.Tensor): Segmentation prediction loss.
"""
seg_preds = semantic_results['seg_preds']
seg_targets = semantic_targets['seg_targets']
positives = (seg_targets > 0)
negative_cls_weights = (seg_targets == 0).float()
seg_weights = (negative_cls_weights + 1.0 * positives).float()
pos_normalizer = positives.sum(dim=0).float()
seg_weights /= torch.clamp(pos_normalizer, min=1.0)
seg_preds = torch.sigmoid(seg_preds)
loss_seg = self.loss_seg(seg_preds, (~positives).long(), seg_weights)
return dict(loss_semantic=loss_seg)
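A small forward-pass sketch for the head above on dummy key point features; it assumes mmdet and mmdet3d are installed so that `mmdet.FocalLoss` can be built, and the sizes are arbitrary.
# Sketch: the head maps (N, 640) key point features to (N, 1) foreground logits.
import torch

head = ForegroundSegmentationHead(in_channels=640)
keypoint_feats = torch.rand(4096, 640)
out = head(keypoint_feats)
print(out['seg_preds'].shape)  # torch.Size([4096, 1])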
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch
from torch.nn import functional as F
from mmdet3d.models.roi_heads.base_3droi_head import Base3DRoIHead
from mmdet3d.registry import MODELS
from mmdet3d.structures import bbox3d2roi
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import InstanceList
from mmdet.models.task_modules import AssignResult
from mmdet.models.task_modules.samplers import SamplingResult
@MODELS.register_module()
class PVRCNNRoiHead(Base3DRoIHead):
"""RoI head for PV-RCNN.
Args:
num_classes (int): The number of classes. Defaults to 3.
semantic_head (dict, optional): Config of semantic head.
Defaults to None.
bbox_roi_extractor (dict, optional): Config of roi_extractor.
Defaults to None.
bbox_head (dict, optional): Config of bbox_head. Defaults to None.
train_cfg (dict, optional): Train config of model.
Defaults to None.
test_cfg (dict, optional): Test config of model.
Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
num_classes: int = 3,
semantic_head: Optional[dict] = None,
bbox_roi_extractor: Optional[dict] = None,
bbox_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None):
super(PVRCNNRoiHead, self).__init__(
bbox_head=bbox_head,
bbox_roi_extractor=bbox_roi_extractor,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
self.num_classes = num_classes
self.semantic_head = MODELS.build(semantic_head)
self.init_assigner_sampler()
@property
def with_semantic(self):
"""bool: whether the head has semantic branch"""
return hasattr(self,
'semantic_head') and self.semantic_head is not None
def loss(self, feats_dict: dict, rpn_results_list: InstanceList,
batch_data_samples: SampleList, **kwargs) -> dict:
"""Training forward function of PVRCNNROIHead.
Args:
feats_dict (dict): Contains point-wise features.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
dict: losses from each head.
- loss_semantic (torch.Tensor): loss of semantic head.
- loss_bbox (torch.Tensor): loss of bboxes.
- loss_cls (torch.Tensor): loss of object classification.
- loss_corner (torch.Tensor): loss of bboxes corners.
"""
losses = dict()
batch_gt_instances_3d = []
batch_gt_instances_ignore = []
for data_sample in batch_data_samples:
batch_gt_instances_3d.append(data_sample.gt_instances_3d)
if 'ignored_instances' in data_sample:
batch_gt_instances_ignore.append(data_sample.ignored_instances)
else:
batch_gt_instances_ignore.append(None)
if self.with_semantic:
semantic_results = self._semantic_forward_train(
feats_dict['keypoint_features'], feats_dict['keypoints'],
batch_gt_instances_3d)
losses['loss_semantic'] = semantic_results['loss_semantic']
sample_results = self._assign_and_sample(rpn_results_list,
batch_gt_instances_3d)
if self.with_bbox:
bbox_results = self._bbox_forward_train(
semantic_results['seg_preds'],
feats_dict['fusion_keypoint_features'],
feats_dict['keypoints'], sample_results)
losses.update(bbox_results['loss_bbox'])
return losses
def predict(self, feats_dict: dict, rpn_results_list: InstanceList,
batch_data_samples: SampleList, **kwargs) -> SampleList:
"""Perform forward propagation of the roi head and predict detection
results on the features of the upstream network.
Args:
feats_dict (dict): Contains point-wise features.
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
list[:obj:`InstanceData`]: Detection results of each sample
after the post process.
Each item usually contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
contains a tensor with shape (num_instances, C), where
C >= 7.
"""
assert self.with_bbox, 'Bbox head must be implemented.'
assert self.with_semantic, 'Semantic head must be implemented.'
batch_input_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
semantic_results = self.semantic_head(feats_dict['keypoint_features'])
point_features = feats_dict[
'fusion_keypoint_features'] * semantic_results[
'seg_preds'].sigmoid().max(
dim=-1, keepdim=True).values
rois = bbox3d2roi(
[res['bboxes_3d'].tensor for res in rpn_results_list])
labels_3d = [res['labels_3d'] for res in rpn_results_list]
bbox_results = self._bbox_forward(point_features,
feats_dict['keypoints'], rois)
results_list = self.bbox_head.get_results(rois,
bbox_results['bbox_scores'],
bbox_results['bbox_reg'],
labels_3d, batch_input_metas,
self.test_cfg)
return results_list
def _bbox_forward_train(self, seg_preds: torch.Tensor,
keypoint_features: torch.Tensor,
keypoints: torch.Tensor,
sampling_results: SamplingResult) -> dict:
"""Forward training function of roi_extractor and bbox_head.
Args:
seg_preds (torch.Tensor): Point-wise segmentation predictions.
keypoint_features (torch.Tensor): Key point features
from the points encoder.
keypoints (torch.Tensor): Coordinate of key points.
sampling_results (:obj:`SamplingResult`): Sampled results used
for training.
Returns:
dict: Forward results including losses and predictions.
"""
rois = bbox3d2roi([res.bboxes for res in sampling_results])
keypoint_features = keypoint_features * seg_preds.sigmoid().max(
dim=-1, keepdim=True).values
bbox_results = self._bbox_forward(keypoint_features, keypoints, rois)
bbox_targets = self.bbox_head.get_targets(sampling_results,
self.train_cfg)
loss_bbox = self.bbox_head.loss(bbox_results['bbox_scores'],
bbox_results['bbox_reg'], rois,
*bbox_targets)
bbox_results.update(loss_bbox=loss_bbox)
return bbox_results
def _bbox_forward(self, keypoint_features: torch.Tensor,
keypoints: torch.Tensor, rois: torch.Tensor) -> dict:
"""Forward function of roi_extractor and bbox_head used in both
training and testing.
Args:
keypoint_features (torch.Tensor): Key point features
from the points encoder.
keypoints (torch.Tensor): Coordinates of key points.
rois (Tensor): RoI boxes.
Returns:
dict: Contains predictions of bbox_head and
features of roi_extractor.
"""
pooled_keypoint_features = self.bbox_roi_extractor(
keypoint_features, keypoints[..., 1:], keypoints[..., 0].int(),
rois)
bbox_score, bbox_reg = self.bbox_head(pooled_keypoint_features)
bbox_results = dict(bbox_scores=bbox_score, bbox_reg=bbox_reg)
return bbox_results
def _assign_and_sample(
self, proposal_list: InstanceList,
batch_gt_instances_3d: InstanceList) -> List[SamplingResult]:
"""Assign and sample proposals for training.
Args:
proposal_list (list[:obj:`InstanceData`]): Proposals produced by
rpn head.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
list[:obj:`SamplingResult`]: Sampled results of each training
sample.
"""
sampling_results = []
# bbox assign
for batch_idx in range(len(proposal_list)):
cur_proposal_list = proposal_list[batch_idx]
cur_boxes = cur_proposal_list['bboxes_3d']
cur_labels_3d = cur_proposal_list['labels_3d']
cur_gt_instances_3d = batch_gt_instances_3d[batch_idx]
cur_gt_instances_3d.bboxes_3d = cur_gt_instances_3d.\
bboxes_3d.tensor
cur_gt_bboxes = batch_gt_instances_3d[batch_idx].bboxes_3d.to(
cur_boxes.device)
cur_gt_labels = batch_gt_instances_3d[batch_idx].labels_3d
batch_num_gts = 0
# 0 is bg
batch_gt_indis = cur_gt_labels.new_full((len(cur_boxes), ), 0)
batch_max_overlaps = cur_boxes.tensor.new_zeros(len(cur_boxes))
# -1 is bg
batch_gt_labels = cur_gt_labels.new_full((len(cur_boxes), ), -1)
# each class may have its own assigner
if isinstance(self.bbox_assigner, list):
for i, assigner in enumerate(self.bbox_assigner):
gt_per_cls = (cur_gt_labels == i)
pred_per_cls = (cur_labels_3d == i)
cur_assign_res = assigner.assign(
cur_proposal_list[pred_per_cls],
cur_gt_instances_3d[gt_per_cls])
# gather assign_results in different class into one result
batch_num_gts += cur_assign_res.num_gts
# gt inds (1-based)
gt_inds_arange_pad = gt_per_cls.nonzero(
as_tuple=False).view(-1) + 1
# pad 0 for indice unassigned
gt_inds_arange_pad = F.pad(
gt_inds_arange_pad, (1, 0), mode='constant', value=0)
# pad -1 for indice ignore
gt_inds_arange_pad = F.pad(
gt_inds_arange_pad, (1, 0), mode='constant', value=-1)
# convert to 0~gt_num+2 for indices
gt_inds_arange_pad += 1
# now 0 is bg, >1 is fg in batch_gt_indis
batch_gt_indis[pred_per_cls] = gt_inds_arange_pad[
cur_assign_res.gt_inds + 1] - 1
batch_max_overlaps[
pred_per_cls] = cur_assign_res.max_overlaps
batch_gt_labels[pred_per_cls] = cur_assign_res.labels
assign_result = AssignResult(batch_num_gts, batch_gt_indis,
batch_max_overlaps,
batch_gt_labels)
else: # for single class
assign_result = self.bbox_assigner.assign(
cur_proposal_list, cur_gt_instances_3d)
# sample boxes
sampling_result = self.bbox_sampler.sample(assign_result,
cur_boxes.tensor,
cur_gt_bboxes,
cur_gt_labels)
sampling_results.append(sampling_result)
return sampling_results
def _semantic_forward_train(self, keypoint_features: torch.Tensor,
keypoints: torch.Tensor,
batch_gt_instances_3d: InstanceList) -> dict:
"""Train semantic head.
Args:
keypoint_features (torch.Tensor): key points features
from points encoder.
keypoints (torch.Tensor): Coordinate of key points.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
dict: Segmentation results including losses
"""
semantic_results = self.semantic_head(keypoint_features)
semantic_targets = self.semantic_head.get_targets(
keypoints, batch_gt_instances_3d)
loss_semantic = self.semantic_head.loss(semantic_results,
semantic_targets)
semantic_results.update(loss_semantic)
return semantic_results
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.roi_heads.roi_extractors import SingleRoIExtractor
from .batch_roigridpoint_extractor import Batch3DRoIGridExtractor
from .single_roiaware_extractor import Single3DRoIAwareExtractor
from .single_roipoint_extractor import Single3DRoIPointExtractor
__all__ = [
'SingleRoIExtractor', 'Single3DRoIAwareExtractor',
'Single3DRoIPointExtractor', 'Batch3DRoIGridExtractor'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule
from mmdet3d.registry import MODELS
from mmdet3d.structures.bbox_3d import rotation_3d_in_axis
@MODELS.register_module()
class Batch3DRoIGridExtractor(BaseModule):
"""Grid point wise roi-aware Extractor.
Args:
grid_size (int): The number of grid points in a roi bbox.
Defaults to 6.
roi_layer (dict, optional): Config of the SA module used to extract
grid point features. Defaults to None.
init_cfg (dict, optional): Initialize config of
model. Defaults to None.
"""
def __init__(self,
grid_size: int = 6,
roi_layer: dict = None,
init_cfg: dict = None) -> None:
super(Batch3DRoIGridExtractor, self).__init__(init_cfg=init_cfg)
self.roi_grid_pool_layer = MODELS.build(roi_layer)
self.grid_size = grid_size
def forward(self, feats: torch.Tensor, coordinate: torch.Tensor,
batch_inds: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
"""Forward roi extractor to extract grid points feature.
Args:
feats (torch.Tensor): Key points features.
coordinate (torch.Tensor): Key points coordinates.
batch_inds (torch.Tensor): Input batch indexes.
rois (torch.Tensor): RoI boxes in shape (N, 1 + 7), where the
first column is the batch index.
Returns:
torch.Tensor: Grid points features.
"""
batch_size = int(batch_inds.max()) + 1
xyz = coordinate
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for k in range(batch_size):
xyz_batch_cnt[k] = (batch_inds == k).sum()
rois_batch_inds = rois[:, 0].int()
# (N1+N2+..., 6x6x6, 3)
roi_grid = self.get_dense_grid_points(rois[:, 1:])
new_xyz = roi_grid.view(-1, 3)
new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int()
for k in range(batch_size):
new_xyz_batch_cnt[k] = ((rois_batch_inds == k).sum() *
roi_grid.size(1))
pooled_points, pooled_features = self.roi_grid_pool_layer(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz.contiguous(),
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=feats.contiguous()) # (M1 + M2 ..., C)
pooled_features = pooled_features.view(-1, self.grid_size,
self.grid_size, self.grid_size,
pooled_features.shape[-1])
# (BxN, 6, 6, 6, C)
return pooled_features
def get_dense_grid_points(self, rois: torch.Tensor) -> torch.Tensor:
"""Get dense grid points from rois.
Args:
rois (torch.Tensor): RoI boxes in shape (N, 7).
Returns:
torch.Tensor: Grid points coordinates.
"""
rois_bbox = rois.clone()
rois_bbox[:, 2] += rois_bbox[:, 5] / 2  # shift z from box bottom to box center
faked_features = rois_bbox.new_ones(
(self.grid_size, self.grid_size, self.grid_size))
dense_idx = faked_features.nonzero()
dense_idx = dense_idx.repeat(rois_bbox.size(0), 1, 1).float()
dense_idx = ((dense_idx + 0.5) / self.grid_size)
dense_idx[..., :3] -= 0.5
roi_ctr = rois_bbox[:, :3]
roi_dim = rois_bbox[:, 3:6]
roi_grid_points = dense_idx * roi_dim.view(-1, 1, 3)
roi_grid_points = rotation_3d_in_axis(
roi_grid_points, rois_bbox[:, 6], axis=2)
roi_grid_points += roi_ctr.view(-1, 1, 3)
return roi_grid_points
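The grid construction in `get_dense_grid_points` places a `grid_size`**3 lattice inside each RoI; the following self-contained snippet reproduces just the offset normalization (without the rotation into the box frame, which needs `rotation_3d_in_axis`) to show the layout.
# Illustration: for grid_size=6 each RoI gets 216 grid points whose normalized
# offsets lie in (-0.5, 0.5) per axis before scaling by box size and rotation.
import torch

grid_size = 6
dense_idx = torch.ones(grid_size, grid_size, grid_size).nonzero()  # (216, 3)
offsets = (dense_idx.float() + 0.5) / grid_size - 0.5
print(offsets.shape, offsets.min().item(), offsets.max().item())
# torch.Size([216, 3]) -0.4166... 0.4166...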
import unittest
import torch
from mmengine import DefaultScope
from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
_get_detector_cfg, _setup_seed)
class TestPVRCNN(unittest.TestCase):
def test_pvrcnn(self):
import mmdet3d.models
assert hasattr(mmdet3d.models, 'PointVoxelRCNN')
DefaultScope.get_instance('test_pvrcnn', scope_name='mmdet3d')
_setup_seed(0)
pvrcnn_cfg = _get_detector_cfg(
'pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py')
model = MODELS.build(pvrcnn_cfg)
num_gt_instance = 2
packed_inputs = _create_detector_inputs(
num_gt_instance=num_gt_instance)
# TODO: Support aug data test
# aug_packed_inputs = [
# _create_detector_inputs(num_gt_instance=num_gt_instance),
# _create_detector_inputs(num_gt_instance=num_gt_instance + 1)
# ]
# test_aug_test
# metainfo = {
# 'pcd_scale_factor': 1,
# 'pcd_horizontal_flip': 1,
# 'pcd_vertical_flip': 1,
# 'box_type_3d': LiDARInstance3DBoxes
# }
# for item in aug_packed_inputs:
# for batch_id in len(item['data_samples']):
# item['data_samples'][batch_id].set_metainfo(metainfo)
if torch.cuda.is_available():
model = model.cuda()
# test simple_test
with torch.no_grad():
data = model.data_preprocessor(packed_inputs, True)
torch.cuda.empty_cache()
results = model.forward(**data, mode='predict')
self.assertEqual(len(results), 1)
self.assertIn('bboxes_3d', results[0].pred_instances_3d)
self.assertIn('scores_3d', results[0].pred_instances_3d)
self.assertIn('labels_3d', results[0].pred_instances_3d)
# save memory
with torch.no_grad():
losses = model.forward(**data, mode='loss')
torch.cuda.empty_cache()
self.assertGreater(losses['loss_rpn_cls'][0], 0)
self.assertGreaterEqual(losses['loss_rpn_bbox'][0], 0)
self.assertGreaterEqual(losses['loss_rpn_dir'][0], 0)
self.assertGreater(losses['loss_semantic'], 0)
self.assertGreaterEqual(losses['loss_bbox'], 0)
self.assertGreaterEqual(losses['loss_cls'], 0)
self.assertGreaterEqual(losses['loss_corner'], 0)