Unverified commit d8c9bc66, authored by VVsssssk, committed by GitHub

[Fix] Refactor Point RCNN and fix it. (#1819)



* add deploy.yaml

* fix

* fix

* fic

* fix

* fix

* fix

* fix

* fix

* fix

* fix bug

* fix comments

* fix comments

* fix

* fix

* Minor fix
Co-authored-by: Tai-Wang <tab_wang@outlook.com>
parent ea131ebc
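The recurring change in this diff is the move to mmengine-style registries: losses are built through MODELS.build instead of mmdet3d's build_loss, and types that live in mmdet are referenced with a scope prefix. A minimal sketch of the new build path, assuming an installed dev-1.x (mmengine-based) mmdet3d; the config values are copied from the diff below:

# sketch only: registry-scoped loss construction
from mmdet3d.registry import MODELS

cls_loss_cfg = dict(
    type='mmdet.FocalLoss',  # 'mmdet.' scope resolves the class from mmdet
    use_sigmoid=True,
    reduction='sum',
    gamma=2.0,
    alpha=0.25,
    loss_weight=1.0)
cls_loss = MODELS.build(cls_loss_cfg)  # replaces build_loss(cls_loss_cfg)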
 model = dict(
     type='PointRCNN',
+    data_preprocessor=dict(type='Det3DDataPreprocessor'),
     backbone=dict(
         type='PointNet2SAMSG',
         in_channels=4,
@@ -34,14 +35,14 @@ model = dict(
             cls_linear_channels=(256, 256),
             reg_linear_channels=(256, 256)),
         cls_loss=dict(
-            type='FocalLoss',
+            type='mmdet.FocalLoss',
             use_sigmoid=True,
             reduction='sum',
             gamma=2.0,
             alpha=0.25,
             loss_weight=1.0),
         bbox_loss=dict(
-            type='SmoothL1Loss',
+            type='mmdet.SmoothL1Loss',
             beta=1.0 / 9.0,
             reduction='sum',
             loss_weight=1.0),
@@ -55,12 +56,22 @@ model = dict(
                                                             1.73]])),
     roi_head=dict(
         type='PointRCNNRoIHead',
-        point_roi_extractor=dict(
+        bbox_roi_extractor=dict(
             type='Single3DRoIPointExtractor',
             roi_layer=dict(type='RoIPointPool3d', num_sampled_points=512)),
         bbox_head=dict(
             type='PointRCNNBboxHead',
             num_classes=1,
+            loss_bbox=dict(
+                type='mmdet.SmoothL1Loss',
+                beta=1.0 / 9.0,
+                reduction='sum',
+                loss_weight=1.0),
+            loss_cls=dict(
+                type='mmdet.CrossEntropyLoss',
+                use_sigmoid=True,
+                reduction='sum',
+                loss_weight=1.0),
             pred_layer_cfg=dict(
                 in_channels=512,
                 cls_conv_channels=(256, 256),
@@ -79,13 +90,16 @@ model = dict(
     train_cfg=dict(
         pos_distance_thr=10.0,
         rpn=dict(
-            nms_cfg=dict(
-                use_rotate_nms=True, iou_thr=0.8, nms_pre=9000, nms_post=512),
-            score_thr=None),
+            rpn_proposal=dict(
+                use_rotate_nms=True,
+                score_thr=None,
+                iou_thr=0.8,
+                nms_pre=9000,
+                nms_post=512)),
         rcnn=dict(
             assigner=[
-                dict(  # for Car
-                    type='MaxIoUAssigner',
+                dict(  # for Pedestrian
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(
                         type='BboxOverlaps3D', coordinate='lidar'),
                     pos_iou_thr=0.55,
@@ -93,8 +107,8 @@ model = dict(
                     min_pos_iou=0.55,
                     ignore_iof_thr=-1,
                     match_low_quality=False),
-                dict(  # for Pedestrian
-                    type='MaxIoUAssigner',
+                dict(  # for Cyclist
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(
                         type='BboxOverlaps3D', coordinate='lidar'),
                     pos_iou_thr=0.55,
@@ -102,8 +116,8 @@ model = dict(
                     min_pos_iou=0.55,
                     ignore_iof_thr=-1,
                     match_low_quality=False),
-                dict(  # for Cyclist
-                    type='MaxIoUAssigner',
+                dict(  # for Car
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(
                         type='BboxOverlaps3D', coordinate='lidar'),
                     pos_iou_thr=0.55,
@@ -126,6 +140,9 @@ model = dict(
     test_cfg=dict(
         rpn=dict(
             nms_cfg=dict(
-                use_rotate_nms=True, iou_thr=0.85, nms_pre=9000, nms_post=512),
-            score_thr=None),
+                use_rotate_nms=True,
+                iou_thr=0.85,
+                nms_pre=9000,
+                nms_post=512,
+                score_thr=None)),
         rcnn=dict(use_rotate_nms=True, nms_thr=0.1, score_thr=0.1)))
@@ -6,7 +6,8 @@ _base_ = [
 # dataset settings
 dataset_type = 'KittiDataset'
 data_root = 'data/kitti/'
-class_names = ['Car', 'Pedestrian', 'Cyclist']
+class_names = ['Pedestrian', 'Cyclist', 'Car']
+metainfo = dict(CLASSES=class_names)
 point_cloud_range = [0, -40, -3, 70.4, 40, 1]
 input_modality = dict(use_lidar=True, use_camera=False)
@@ -42,8 +43,9 @@ train_pipeline = [
     dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
     dict(type='PointSample', num_points=16384, sample_range=40.0),
     dict(type='PointShuffle'),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
-    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
+    dict(
+        type='Pack3DDetInputs',
+        keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
 ]
 test_pipeline = [
     dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
@@ -61,41 +63,67 @@ test_pipeline = [
             dict(type='RandomFlip3D'),
             dict(
                 type='PointsRangeFilter', point_cloud_range=point_cloud_range),
-            dict(type='PointSample', num_points=16384, sample_range=40.0),
-            dict(
-                type='DefaultFormatBundle3D',
-                class_names=class_names,
-                with_label=False),
-            dict(type='Collect3D', keys=['points'])
-        ])
+            dict(type='PointSample', num_points=16384, sample_range=40.0)
+        ]),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]

-data = dict(
-    samples_per_gpu=2,
-    workers_per_gpu=2,
-    train=dict(
+train_dataloader = dict(
+    batch_size=2,
+    num_workers=2,
+    dataset=dict(
         type='RepeatDataset',
         times=2,
-        dataset=dict(pipeline=train_pipeline, classes=class_names)),
-    val=dict(pipeline=test_pipeline, classes=class_names),
-    test=dict(pipeline=test_pipeline, classes=class_names))
+        dataset=dict(pipeline=train_pipeline, metainfo=metainfo)))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))

-# optimizer
 lr = 0.001  # max learning rate
-optimizer = dict(lr=lr, betas=(0.95, 0.85))
-# runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=80)
-evaluation = dict(interval=2)
-# yapf:disable
-log_config = dict(
-    interval=30,
-    hooks=[
-        dict(type='TextLoggerHook'),
-        dict(type='TensorboardLoggerHook')
-    ])
-# yapf:enable
+optim_wrapper = dict(optimizer=dict(lr=lr, betas=(0.95, 0.85)))
+train_cfg = dict(by_epoch=True, max_epochs=80, val_interval=2)

 # Default setting for scaling LR automatically
 # - `enable` means enable scaling LR automatically
 #   or not by default.
 # - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
 auto_scale_lr = dict(enable=False, base_batch_size=16)
+param_scheduler = [
+    # learning rate scheduler
+    # During the first 35 epochs, learning rate increases from 0 to lr * 10
+    # during the next 45 epochs, learning rate decreases from lr * 10 to
+    # lr * 1e-4
+    dict(
+        type='CosineAnnealingLR',
+        T_max=35,
+        eta_min=lr * 10,
+        begin=0,
+        end=35,
+        by_epoch=True,
+        convert_to_iter_based=True),
+    dict(
+        type='CosineAnnealingLR',
+        T_max=45,
+        eta_min=lr * 1e-4,
+        begin=35,
+        end=80,
+        by_epoch=True,
+        convert_to_iter_based=True),
+    # momentum scheduler
+    # During the first 35 epochs, momentum increases from 0 to 0.85 / 0.95
+    # during the next 45 epochs, momentum increases from 0.85 / 0.95 to 1
+    dict(
+        type='CosineAnnealingMomentum',
+        T_max=35,
+        eta_min=0.85 / 0.95,
+        begin=0,
+        end=35,
+        by_epoch=True,
+        convert_to_iter_based=True),
+    dict(
+        type='CosineAnnealingMomentum',
+        T_max=45,
+        eta_min=1,
+        begin=35,
+        end=80,
+        by_epoch=True,
+        convert_to_iter_based=True)
+]
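A back-of-envelope sketch (not part of the diff) of the LR curve the two CosineAnnealingLR phases above produce: each phase anneals from the current LR toward its eta_min, so with lr = 0.001 the rate rises to 0.01 over epochs 0-35, then decays to 1e-7 by epoch 80:

# sketch: closed form eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi*t/T)) / 2
import math

def lr_at(epoch: float, lr: float = 0.001) -> float:
    if epoch < 35:  # phase 1: cosine from lr up to lr * 10
        eta_max, eta_min, t = lr, lr * 10, epoch / 35
    else:  # phase 2: cosine from lr * 10 down to lr * 1e-4
        eta_max, eta_min, t = lr * 10, lr * 1e-4, (epoch - 35) / 45
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t))

assert abs(lr_at(0) - 0.001) < 1e-12 and abs(lr_at(35) - 0.01) < 1e-12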
@@ -365,7 +365,7 @@ class PartAggregationROIHead(Base3DRoIHead):
         Args:
             feats_dict (dict): Contains features from the first stage.
-            rpn_results_list (List[:obj:`InstancesData`]): Detection results
+            rpn_results_list (List[:obj:`InstanceData`]): Detection results
                 of rpn head.
             batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                 samples. It usually includes information such as
@@ -412,7 +412,7 @@ class PartAggregationROIHead(Base3DRoIHead):
             voxel_dict (dict): Contains information of voxels.
             batch_input_metas (list[dict], Optional): Batch image meta info.
                 Defaults to None.
-            rpn_results_list (List[:obj:`InstancesData`]): Detection results
+            rpn_results_list (List[:obj:`InstanceData`]): Detection results
                 of rpn head.
             test_cfg (Config): Test config.
@@ -438,7 +438,7 @@ class PartAggregationROIHead(Base3DRoIHead):
         Args:
             feats_dict (dict): Contains features from the first stage.
-            rpn_results_list (List[:obj:`InstancesData`]): Detection results
+            rpn_results_list (List[:obj:`InstanceData`]): Detection results
                 of rpn head.
             batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                 samples. It usually includes information such as
......
@@ -204,7 +204,7 @@ class Base3DDenseHead(BaseModule, metaclass=ABCMeta):
             score_factors (list[Tensor], optional): Score factor for
                 all scale level, each is a 4D-tensor, has shape
                 (batch_size, num_priors * 1, H, W). Defaults to None.
-            batch_input_metas (list[dict], Optional): Batch image meta info.
+            batch_input_metas (list[dict], Optional): Batch inputs meta info.
                 Defaults to None.
             cfg (ConfigDict, optional): Test / postprocessing
                 configuration, if None, test_cfg would be used.
......
@@ -183,8 +183,7 @@ class PartA2RPNHead(Anchor3DHead):
         result = self.class_agnostic_nms(mlvl_bboxes, mlvl_bboxes_for_nms,
                                          mlvl_max_scores, mlvl_label_pred,
                                          mlvl_cls_score, mlvl_dir_scores,
-                                         score_thr, cfg.nms_post, cfg,
-                                         input_meta)
+                                         score_thr, cfg, input_meta)
         return result

     def loss_and_predict(self,
@@ -275,7 +274,7 @@ class PartA2RPNHead(Anchor3DHead):
                            mlvl_bboxes_for_nms: Tensor,
                            mlvl_max_scores: Tensor, mlvl_label_pred: Tensor,
                            mlvl_cls_score: Tensor, mlvl_dir_scores: Tensor,
-                           score_thr: int, max_num: int, cfg: ConfigDict,
+                           score_thr: int, cfg: ConfigDict,
                            input_meta: dict) -> Dict:
         """Class agnostic nms for single batch.
@@ -291,7 +290,6 @@ class PartA2RPNHead(Anchor3DHead):
             mlvl_dir_scores (torch.Tensor): Direction scores of
                 Multi-level bbox.
             score_thr (int): Score threshold.
-            max_num (int): Max number of bboxes after nms.
             cfg (:obj:`ConfigDict`): Training or testing config.
             input_meta (dict): Contain pcd and img's meta info.
@@ -339,9 +337,9 @@ class PartA2RPNHead(Anchor3DHead):
             scores = torch.cat(scores, dim=0)
             cls_scores = torch.cat(cls_scores, dim=0)
             labels = torch.cat(labels, dim=0)
-            if bboxes.shape[0] > max_num:
+            if bboxes.shape[0] > cfg.nms_post:
                 _, inds = scores.sort(descending=True)
-                inds = inds[:max_num]
+                inds = inds[:cfg.nms_post]
                 bboxes = bboxes[inds, :]
                 labels = labels[inds]
                 scores = scores[inds]
......
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple
+
 import torch
 from mmengine.model import BaseModule
+from mmengine.structures import InstanceData
+from torch import Tensor
 from torch import nn as nn

-from mmdet3d.models.builder import build_loss
 from mmdet3d.models.layers import nms_bev, nms_normal_bev
 from mmdet3d.registry import MODELS, TASK_UTILS
 from mmdet3d.structures import xywhr2xyxyr
-from mmdet3d.structures.bbox_3d import (DepthInstance3DBoxes,
+from mmdet3d.structures.bbox_3d import (BaseInstance3DBoxes,
+                                        DepthInstance3DBoxes,
                                         LiDARInstance3DBoxes)
+from mmdet3d.structures.det3d_data_sample import SampleList
+from mmdet3d.utils.typing import InstanceList
 from mmdet.models.utils import multi_apply
@@ -34,15 +40,15 @@ class PointRPNHead(BaseModule):
     """

     def __init__(self,
-                 num_classes,
-                 train_cfg,
-                 test_cfg,
-                 pred_layer_cfg=None,
-                 enlarge_width=0.1,
-                 cls_loss=None,
-                 bbox_loss=None,
-                 bbox_coder=None,
-                 init_cfg=None):
+                 num_classes: int,
+                 train_cfg: dict,
+                 test_cfg: dict,
+                 pred_layer_cfg: Optional[dict] = None,
+                 enlarge_width: float = 0.1,
+                 cls_loss: Optional[dict] = None,
+                 bbox_loss: Optional[dict] = None,
+                 bbox_coder: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None) -> None:
         super().__init__(init_cfg=init_cfg)
         self.num_classes = num_classes
         self.train_cfg = train_cfg
@@ -50,8 +56,8 @@ class PointRPNHead(BaseModule):
         self.enlarge_width = enlarge_width

         # build loss function
-        self.bbox_loss = build_loss(bbox_loss)
-        self.cls_loss = build_loss(cls_loss)
+        self.bbox_loss = MODELS.build(bbox_loss)
+        self.cls_loss = MODELS.build(cls_loss)

         # build box coder
         self.bbox_coder = TASK_UTILS.build(bbox_coder)
@@ -67,7 +73,8 @@ class PointRPNHead(BaseModule):
             input_channels=pred_layer_cfg.in_channels,
             output_channels=self._get_reg_out_channels())

-    def _make_fc_layers(self, fc_cfg, input_channels, output_channels):
+    def _make_fc_layers(self, fc_cfg: dict, input_channels: int,
+                        output_channels: int) -> nn.Sequential:
         """Make fully connect layers.

         Args:
@@ -102,7 +109,7 @@ class PointRPNHead(BaseModule):
         #           torch.cos(yaw) (1), torch.sin(yaw) (1)
         return self.bbox_coder.code_size

-    def forward(self, feat_dict):
+    def forward(self, feat_dict: dict) -> Tuple[List[Tensor]]:
         """Forward pass.

         Args:
@@ -124,30 +131,35 @@ class PointRPNHead(BaseModule):
             batch_size, -1, self._get_reg_out_channels())
         return point_box_preds, point_cls_preds

-    def loss(self,
-             bbox_preds,
-             cls_preds,
-             points,
-             gt_bboxes_3d,
-             gt_labels_3d,
-             img_metas=None):
+    def loss_by_feat(
+            self,
+            bbox_preds: List[Tensor],
+            cls_preds: List[Tensor],
+            points: List[Tensor],
+            batch_gt_instances_3d: InstanceList,
+            batch_input_metas: Optional[List[dict]] = None,
+            batch_gt_instances_ignore: Optional[InstanceList] = None) -> Dict:
         """Compute loss.

         Args:
-            bbox_preds (dict): Predictions from forward of PointRCNN RPN_Head.
-            cls_preds (dict): Classification from forward of PointRCNN
-                RPN_Head.
+            bbox_preds (list[torch.Tensor]): Predictions from forward of
+                PointRCNN RPN_Head.
+            cls_preds (list[torch.Tensor]): Classification from forward of
+                PointRCNN RPN_Head.
             points (list[torch.Tensor]): Input points.
-            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
-                bboxes of each sample.
-            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
-            img_metas (list[dict], Optional): Contain pcd and img's meta info.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances_3d. It usually includes ``bboxes_3d`` and
+                ``labels_3d`` attributes.
+            batch_input_metas (list[dict]): Contain pcd and img's meta info.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
                 Defaults to None.

         Returns:
             dict: Losses of PointRCNN RPN module.
         """
-        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d)
+        targets = self.get_targets(points, batch_gt_instances_3d)
         (bbox_targets, mask_targets, positive_mask, negative_mask,
          box_loss_weights, point_targets) = targets
@@ -169,25 +181,25 @@ class PointRPNHead(BaseModule):

         return losses

-    def get_targets(self, points, gt_bboxes_3d, gt_labels_3d):
+    def get_targets(self, points: List[Tensor],
+                    batch_gt_instances_3d: InstanceList) -> Tuple[Tensor]:
         """Generate targets of PointRCNN RPN head.

         Args:
-            points (list[torch.Tensor]): Points of each batch.
-            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
-                bboxes of each batch.
-            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
+            points (list[torch.Tensor]): Points in one batch.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances_3d. It usually includes ``bboxes_3d`` and
+                ``labels_3d`` attributes.

         Returns:
             tuple[torch.Tensor]: Targets of PointRCNN RPN head.
         """
-        # find empty example
-        for index in range(len(gt_labels_3d)):
-            if len(gt_labels_3d[index]) == 0:
-                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
-                    1, gt_bboxes_3d[index].tensor.shape[-1])
-                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
-                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
+        gt_labels_3d = [
+            instances.labels_3d for instances in batch_gt_instances_3d
+        ]
+        gt_bboxes_3d = [
+            instances.bboxes_3d for instances in batch_gt_instances_3d
+        ]

         (bbox_targets, mask_targets, positive_mask, negative_mask,
          point_targets) = multi_apply(self.get_targets_single, points,
@@ -202,7 +214,9 @@ class PointRPNHead(BaseModule):
         return (bbox_targets, mask_targets, positive_mask, negative_mask,
                 box_loss_weights, point_targets)

-    def get_targets_single(self, points, gt_bboxes_3d, gt_labels_3d):
+    def get_targets_single(self, points: Tensor,
+                           gt_bboxes_3d: BaseInstance3DBoxes,
+                           gt_labels_3d: Tensor) -> Tuple[Tensor]:
         """Generate targets of PointRCNN RPN head for single batch.

         Args:
@@ -243,24 +257,34 @@ class PointRPNHead(BaseModule):
         return (bbox_targets, mask_targets, positive_mask, negative_mask,
                 point_targets)

-    def get_bboxes(self,
-                   points,
-                   bbox_preds,
-                   cls_preds,
-                   input_metas,
-                   rescale=False):
+    def predict_by_feat(self, points: Tensor, bbox_preds: List[Tensor],
+                        cls_preds: List[Tensor], batch_input_metas: List[dict],
+                        cfg: Optional[dict]) -> InstanceList:
         """Generate bboxes from RPN head predictions.

         Args:
             points (torch.Tensor): Input points.
-            bbox_preds (dict): Regression predictions from PointRCNN head.
-            cls_preds (dict): Class scores predictions from PointRCNN head.
-            input_metas (list[dict]): Point cloud and image's meta info.
-            rescale (bool, optional): Whether to rescale bboxes.
-                Defaults to False.
+            bbox_preds (list[tensor]): Regression predictions from PointRCNN
+                head.
+            cls_preds (list[tensor]): Class scores predictions from PointRCNN
+                head.
+            batch_input_metas (list[dict]): Batch inputs meta info.
+            cfg (ConfigDict, optional): Test / postprocessing
+                configuration.

         Returns:
-            list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
+            list[:obj:`InstanceData`]: Detection results of each sample
+            after the post process.
+            Each item usually contains following keys.
+
+            - scores_3d (Tensor): Classification scores, has a shape
+              (num_instances, )
+            - labels_3d (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
+              contains a tensor with shape (num_instances, C), where
+              C >= 7.
+            - cls_preds (torch.Tensor): Class score of each bbox.
         """
         sem_scores = cls_preds.sigmoid()
         obj_scores = sem_scores.max(-1)[0]
@@ -271,30 +295,40 @@ class PointRPNHead(BaseModule):
         for b in range(batch_size):
             bbox3d = self.bbox_coder.decode(bbox_preds[b], points[b, ..., :3],
                                             object_class[b])
+            mask = ~bbox3d.sum(dim=1).isinf()
             bbox_selected, score_selected, labels, cls_preds_selected = \
-                self.class_agnostic_nms(obj_scores[b], sem_scores[b], bbox3d,
-                                        points[b, ..., :3], input_metas[b])
-            bbox = input_metas[b]['box_type_3d'](
-                bbox_selected.clone(),
-                box_dim=bbox_selected.shape[-1],
-                with_yaw=True)
-            results.append((bbox, score_selected, labels, cls_preds_selected))
+                self.class_agnostic_nms(obj_scores[b][mask],
+                                        sem_scores[b][mask, :],
+                                        bbox3d[mask, :],
+                                        points[b, ..., :3][mask, :],
+                                        batch_input_metas[b],
+                                        cfg.nms_cfg)
+            bbox_selected = batch_input_metas[b]['box_type_3d'](
+                bbox_selected, box_dim=bbox_selected.shape[-1])
+            result = InstanceData()
+            result.bboxes_3d = bbox_selected
+            result.scores_3d = score_selected
+            result.labels_3d = labels
+            result.cls_preds = cls_preds_selected
+            results.append(result)
         return results

-    def class_agnostic_nms(self, obj_scores, sem_scores, bbox, points,
-                           input_meta):
+    def class_agnostic_nms(self, obj_scores: Tensor, sem_scores: Tensor,
+                           bbox: Tensor, points: Tensor, input_meta: Dict,
+                           nms_cfg: Dict) -> Tuple[Tensor]:
         """Class agnostic nms.

         Args:
             obj_scores (torch.Tensor): Objectness score of bounding boxes.
             sem_scores (torch.Tensor): Semantic class score of bounding boxes.
             bbox (torch.Tensor): Predicted bounding boxes.
+            points (torch.Tensor): Input points.
+            input_meta (dict): Contain pcd and img's meta info.
+            nms_cfg (dict): NMS config dict.

         Returns:
             tuple[torch.Tensor]: Bounding boxes, scores and labels.
         """
-        nms_cfg = self.test_cfg.nms_cfg if not self.training \
-            else self.train_cfg.nms_cfg
         if nms_cfg.use_rotate_nms:
             nms_func = nms_bev
         else:
@@ -323,14 +357,14 @@ class PointRPNHead(BaseModule):
         bbox = bbox[nonempty_box_mask]

-        if self.test_cfg.score_thr is not None:
-            score_thr = self.test_cfg.score_thr
+        if nms_cfg.score_thr is not None:
+            score_thr = nms_cfg.score_thr
             keep = (obj_scores >= score_thr)
             obj_scores = obj_scores[keep]
             sem_scores = sem_scores[keep]
             bbox = bbox.tensor[keep]

-        if obj_scores.shape[0] > 0:
+        if bbox.tensor.shape[0] > 0:
             topk = min(nms_cfg.nms_pre, obj_scores.shape[0])
             obj_scores_nms, indices = torch.topk(obj_scores, k=topk)
             bbox_for_nms = xywhr2xyxyr(bbox[indices].bev)
@@ -343,15 +377,22 @@ class PointRPNHead(BaseModule):
             score_selected = obj_scores_nms[keep]
             cls_preds = sem_scores_nms[keep]
             labels = torch.argmax(cls_preds, -1)
+            if bbox_selected.shape[0] > nms_cfg.nms_post:
+                _, inds = score_selected.sort(descending=True)
+                inds = inds[:nms_cfg.nms_post]
+                bbox_selected = bbox_selected[inds, :]
+                labels = labels[inds]
+                score_selected = score_selected[inds]
+                cls_preds = cls_preds[inds, :]
         else:
             bbox_selected = bbox.tensor
             score_selected = obj_scores.new_zeros([0])
             labels = obj_scores.new_zeros([0])
             cls_preds = obj_scores.new_zeros([0, sem_scores.shape[-1]])

         return bbox_selected, score_selected, labels, cls_preds

-    def _assign_targets_by_points_inside(self, bboxes_3d, points):
+    def _assign_targets_by_points_inside(self, bboxes_3d: BaseInstance3DBoxes,
+                                         points: Tensor) -> Tuple[Tensor]:
         """Compute assignment by checking whether point is inside bbox.

         Args:
@@ -379,3 +420,92 @@ class PointRPNHead(BaseModule):
             raise NotImplementedError('Unsupported bbox type!')

         return points_mask, assignment
+
+    def predict(self, feats_dict: Dict,
+                batch_data_samples: SampleList) -> InstanceList:
+        """Perform forward propagation of the 3D detection head and predict
+        detection results on the features of the upstream network.
+
+        Args:
+            feats_dict (dict): Contains features from the first stage.
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                samples. It usually includes information such as
+                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
+
+        Returns:
+            list[:obj:`InstanceData`]: Detection results of each sample
+            after the post process.
+            Each item usually contains following keys.
+
+            - scores_3d (Tensor): Classification scores, has a shape
+              (num_instances, )
+            - labels_3d (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
+              contains a tensor with shape (num_instances, C), where
+              C >= 7.
+        """
+        batch_input_metas = [
+            data_samples.metainfo for data_samples in batch_data_samples
+        ]
+        raw_points = feats_dict.pop('raw_points')
+        bbox_preds, cls_preds = self(feats_dict)
+        proposal_cfg = self.test_cfg
+        proposal_list = self.predict_by_feat(
+            raw_points,
+            bbox_preds,
+            cls_preds,
+            cfg=proposal_cfg,
+            batch_input_metas=batch_input_metas)
+        feats_dict['points_cls_preds'] = cls_preds
+        return proposal_list
+
+    def loss_and_predict(self,
+                         feats_dict: Dict,
+                         batch_data_samples: SampleList,
+                         proposal_cfg: Optional[dict] = None,
+                         **kwargs) -> Tuple[dict, InstanceList]:
+        """Perform forward propagation of the head, then calculate loss and
+        predictions from the features and data samples.
+
+        Args:
+            feats_dict (dict): Contains features from the first stage.
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                samples. It usually includes information such as
+                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
+            proposal_cfg (ConfigDict, optional): Proposal config.
+
+        Returns:
+            tuple: the return value is a tuple contains:
+
+            - losses: (dict[str, Tensor]): A dictionary of loss components.
+            - predictions (list[:obj:`InstanceData`]): Detection
+              results of each sample after the post process.
+        """
+        batch_gt_instances_3d = []
+        batch_gt_instances_ignore = []
+        batch_input_metas = []
+        for data_sample in batch_data_samples:
+            batch_input_metas.append(data_sample.metainfo)
+            batch_gt_instances_3d.append(data_sample.gt_instances_3d)
+            batch_gt_instances_ignore.append(
+                data_sample.get('ignored_instances', None))
+        raw_points = feats_dict.pop('raw_points')
+        bbox_preds, cls_preds = self(feats_dict)
+
+        loss_inputs = (bbox_preds, cls_preds,
+                       raw_points) + (batch_gt_instances_3d, batch_input_metas,
+                                      batch_gt_instances_ignore)
+        losses = self.loss_by_feat(*loss_inputs)
+
+        predictions = self.predict_by_feat(
+            raw_points,
+            bbox_preds,
+            cls_preds,
+            batch_input_metas=batch_input_metas,
+            cfg=proposal_cfg)
+        feats_dict['points_cls_preds'] = cls_preds
+        if predictions[0].bboxes_3d.tensor.isinf().any():
+            print(predictions)
+        return losses, predictions
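A minimal sketch (hypothetical driver code, not part of the diff) of how the two entry points added above replace the old forward_train/get_bboxes pair; `detector` and `data_samples` stand in for a built PointRCNN model and a batch of Det3DDataSample, and the proposal_cfg shown reuses the test-time RPN settings as a placeholder:

# sketch: feats_dict carries 'raw_points', which the entry point pops
feats_dict = detector.extract_feat(batch_inputs_dict)
losses, proposal_list = detector.rpn_head.loss_and_predict(
    feats_dict, data_samples, proposal_cfg=detector.test_cfg.rpn)
# each proposal is an InstanceData with bboxes_3d / scores_3d / labels_3d /
# cls_preds, consumed by PointRCNNRoIHead as second-stage input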
@@ -14,7 +14,6 @@ from mmdet3d.structures.bbox_3d import (DepthInstance3DBoxes,
                                         LiDARInstance3DBoxes,
                                         rotation_3d_in_axis)
 from mmdet.models.utils import multi_apply
-from ..builder import build_loss
 from .vote_head import VoteHead
@@ -76,8 +75,8 @@ class SSD3DHead(VoteHead):
             size_res_loss=size_res_loss,
             semantic_loss=None,
             init_cfg=init_cfg)
-        self.corner_loss = build_loss(corner_loss)
-        self.vote_loss = build_loss(vote_loss)
+        self.corner_loss = MODELS.build(corner_loss)
+        self.vote_loss = MODELS.build(vote_loss)
         self.num_candidates = vote_module_cfg['num_points']

     def _get_cls_out_channels(self) -> int:
......
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional
+
 import torch

 from mmdet3d.registry import MODELS
@@ -23,14 +25,14 @@ class PointRCNN(TwoStage3DDetector):
     """

     def __init__(self,
-                 backbone,
-                 neck=None,
-                 rpn_head=None,
-                 roi_head=None,
-                 train_cfg=None,
-                 test_cfg=None,
-                 pretrained=None,
-                 init_cfg=None):
+                 backbone: dict,
+                 neck: Optional[dict] = None,
+                 rpn_head: Optional[dict] = None,
+                 roi_head: Optional[dict] = None,
+                 train_cfg: Optional[dict] = None,
+                 test_cfg: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None,
+                 data_preprocessor: Optional[dict] = None) -> Optional:
         super(PointRCNN, self).__init__(
             backbone=backbone,
             neck=neck,
@@ -38,111 +40,28 @@ class PointRCNN(TwoStage3DDetector):
             roi_head=roi_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
-            pretrained=pretrained,
-            init_cfg=init_cfg)
+            init_cfg=init_cfg,
+            data_preprocessor=data_preprocessor)

-    def extract_feat(self, points):
+    def extract_feat(self, batch_inputs_dict: Dict) -> Dict:
         """Directly extract features from the backbone+neck.

         Args:
-            points (torch.Tensor): Input points.
+            batch_inputs_dict (dict): The model input dict which include
+                'points', 'imgs' keys.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.
+                - imgs (torch.Tensor, optional): Image of each sample.

         Returns:
-            dict: Features from the backbone+neck
+            dict: Features from the backbone+neck and raw points.
         """
+        points = torch.stack(batch_inputs_dict['points'])
         x = self.backbone(points)
         if self.with_neck:
             x = self.neck(x)
-        return x
+        return dict(
+            fp_features=x['fp_features'].clone(),
+            fp_points=x['fp_xyz'].clone(),
+            raw_points=points)

-    def forward_train(self, points, input_metas, gt_bboxes_3d, gt_labels_3d):
-        """Forward of training.
-
-        Args:
-            points (list[torch.Tensor]): Points of each batch.
-            input_metas (list[dict]): Meta information of each sample.
-            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
-            gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
-
-        Returns:
-            dict: Losses.
-        """
-        losses = dict()
-        stack_points = torch.stack(points)
-        x = self.extract_feat(stack_points)
-
-        # features for rcnn
-        backbone_feats = x['fp_features'].clone()
-        backbone_xyz = x['fp_xyz'].clone()
-        rcnn_feats = {'features': backbone_feats, 'points': backbone_xyz}
-
-        bbox_preds, cls_preds = self.rpn_head(x)
-
-        rpn_loss = self.rpn_head.loss(
-            bbox_preds=bbox_preds,
-            cls_preds=cls_preds,
-            points=points,
-            gt_bboxes_3d=gt_bboxes_3d,
-            gt_labels_3d=gt_labels_3d,
-            input_metas=input_metas)
-        losses.update(rpn_loss)
-
-        bbox_list = self.rpn_head.get_bboxes(stack_points, bbox_preds,
-                                             cls_preds, input_metas)
-        proposal_list = [
-            dict(
-                boxes_3d=bboxes,
-                scores_3d=scores,
-                labels_3d=labels,
-                cls_preds=preds_cls)
-            for bboxes, scores, labels, preds_cls in bbox_list
-        ]
-        rcnn_feats.update({'points_cls_preds': cls_preds})
-
-        roi_losses = self.roi_head.forward_train(rcnn_feats, input_metas,
-                                                 proposal_list, gt_bboxes_3d,
-                                                 gt_labels_3d)
-        losses.update(roi_losses)
-
-        return losses
-
-    def simple_test(self, points, img_metas, imgs=None, rescale=False):
-        """Forward of testing.
-
-        Args:
-            points (list[torch.Tensor]): Points of each sample.
-            img_metas (list[dict]): Image metas.
-            imgs (list[torch.Tensor], optional): Images of each sample.
-                Defaults to None.
-            rescale (bool, optional): Whether to rescale results.
-                Defaults to False.
-
-        Returns:
-            list: Predicted 3d boxes.
-        """
-        stack_points = torch.stack(points)
-        x = self.extract_feat(stack_points)
-
-        # features for rcnn
-        backbone_feats = x['fp_features'].clone()
-        backbone_xyz = x['fp_xyz'].clone()
-        rcnn_feats = {'features': backbone_feats, 'points': backbone_xyz}
-        bbox_preds, cls_preds = self.rpn_head(x)
-        rcnn_feats.update({'points_cls_preds': cls_preds})
-
-        bbox_list = self.rpn_head.get_bboxes(
-            stack_points, bbox_preds, cls_preds, img_metas, rescale=rescale)
-        proposal_list = [
-            dict(
-                boxes_3d=bboxes,
-                scores_3d=scores,
-                labels_3d=labels,
-                cls_preds=preds_cls)
-            for bboxes, scores, labels, preds_cls in bbox_list
-        ]
-        bbox_results = self.roi_head.simple_test(rcnn_feats, img_metas,
-                                                 proposal_list)
-        return bbox_results
@@ -100,8 +100,9 @@ class TwoStage3DDetector(Base3DDetector):
             keys = rpn_losses.keys()
             for key in keys:
                 if 'loss' in key and 'rpn' not in key:
-                    rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
-            losses.update(rpn_losses)
+                    losses[f'rpn_{key}'] = rpn_losses[key]
+                else:
+                    losses[key] = rpn_losses[key]
         else:
             # TODO: Not support currently, should have a check at Fast R-CNN
             assert batch_data_samples[0].get('proposals', None) is not None
......
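The renaming logic above is easiest to see on a concrete dict; a small self-contained sketch (values are illustrative only):

# sketch: loss keys from the RPN head get an 'rpn_' prefix so they do not
# collide with the RoI head's keys; non-loss keys pass through unchanged
rpn_losses = {'bbox_loss': 1.0, 'cls_loss': 2.0, 'rpn_acc': 0.9}
losses = {}
for key in rpn_losses.keys():
    if 'loss' in key and 'rpn' not in key:
        losses[f'rpn_{key}'] = rpn_losses[key]
    else:
        losses[key] = rpn_losses[key]
# losses == {'rpn_bbox_loss': 1.0, 'rpn_cls_loss': 2.0, 'rpn_acc': 0.9}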
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List
+from typing import Dict, List, Tuple

 import numpy as np
 import torch
@@ -10,6 +10,7 @@ from torch import Tensor

 from mmdet3d.models import make_sparse_convmodule
 from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE
+from mmdet3d.utils.typing import InstanceList
 from mmdet.models.utils import multi_apply

 if IS_SPCONV2_AVAILABLE:
@@ -21,11 +22,11 @@ else:
 from mmengine.model import BaseModule
 from torch import nn as nn

-from mmdet3d.models.builder import build_loss
 from mmdet3d.models.layers import nms_bev, nms_normal_bev
 from mmdet3d.registry import MODELS, TASK_UTILS
 from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes,
                                         rotation_3d_in_axis, xywhr2xyxyr)
+from mmdet3d.utils.typing import SamplingResultList

 @MODELS.register_module()
@@ -56,40 +57,40 @@ class PartA2BboxHead(BaseModule):
         conv_cfg (dict): Config dict of convolutional layers
         norm_cfg (dict): Config dict of normalization layers
         loss_bbox (dict): Config dict of box regression loss.
-        loss_cls (dict): Config dict of classifacation loss.
+        loss_cls (dict, optional): Config dict of classifacation loss.
     """

     def __init__(self,
-                 num_classes,
-                 seg_in_channels,
-                 part_in_channels,
-                 seg_conv_channels=None,
-                 part_conv_channels=None,
-                 merge_conv_channels=None,
-                 down_conv_channels=None,
-                 shared_fc_channels=None,
-                 cls_channels=None,
-                 reg_channels=None,
-                 dropout_ratio=0.1,
-                 roi_feat_size=14,
-                 with_corner_loss=True,
-                 bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
-                 conv_cfg=dict(type='Conv1d'),
-                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
-                 loss_bbox=dict(
-                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
-                 loss_cls=dict(
-                     type='CrossEntropyLoss',
-                     use_sigmoid=True,
-                     reduction='none',
-                     loss_weight=1.0),
-                 init_cfg=None):
+                 num_classes: int,
+                 seg_in_channels: int,
+                 part_in_channels: int,
+                 seg_conv_channels: List[int] = None,
+                 part_conv_channels: List[int] = None,
+                 merge_conv_channels: List[int] = None,
+                 down_conv_channels: List[int] = None,
+                 shared_fc_channels: List[int] = None,
+                 cls_channels: List[int] = None,
+                 reg_channels: List[int] = None,
+                 dropout_ratio: float = 0.1,
+                 roi_feat_size: int = 14,
+                 with_corner_loss: bool = True,
+                 bbox_coder: dict = dict(type='DeltaXYZWLHRBBoxCoder'),
+                 conv_cfg: dict = dict(type='Conv1d'),
+                 norm_cfg: dict = dict(type='BN1d', eps=1e-3, momentum=0.01),
+                 loss_bbox: dict = dict(
+                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
+                 loss_cls: dict = dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=True,
+                     reduction='none',
+                     loss_weight=1.0),
+                 init_cfg: dict = None) -> None:
         super(PartA2BboxHead, self).__init__(init_cfg=init_cfg)
         self.num_classes = num_classes
         self.with_corner_loss = with_corner_loss
         self.bbox_coder = TASK_UTILS.build(bbox_coder)
-        self.loss_bbox = build_loss(loss_bbox)
-        self.loss_cls = build_loss(loss_cls)
+        self.loss_bbox = MODELS.build(loss_bbox)
+        self.loss_cls = MODELS.build(loss_cls)
         self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)

         assert down_conv_channels[-1] == shared_fc_channels[0]
@@ -244,7 +245,7 @@ class PartA2BboxHead(BaseModule):
         super().init_weights()
         normal_init(self.conv_reg[-1].conv, mean=0, std=0.001)

-    def forward(self, seg_feats, part_feats):
+    def forward(self, seg_feats: Tensor, part_feats: Tensor) -> Tuple[Tensor]:
         """Forward pass.

         Args:
@@ -294,8 +295,10 @@ class PartA2BboxHead(BaseModule):

         return cls_score, bbox_pred

-    def loss(self, cls_score, bbox_pred, rois, labels, bbox_targets,
-             pos_gt_bboxes, reg_mask, label_weights, bbox_weights):
+    def loss(self, cls_score: Tensor, bbox_pred: Tensor, rois: Tensor,
+             labels: Tensor, bbox_targets: Tensor, pos_gt_bboxes: Tensor,
+             reg_mask: Tensor, label_weights: Tensor,
+             bbox_weights: Tensor) -> Dict:
         """Computing losses.

         Args:
@@ -329,9 +332,9 @@ class PartA2BboxHead(BaseModule):
         pos_inds = (reg_mask > 0)
         if pos_inds.any() == 0:
             # fake a part loss
-            losses['loss_bbox'] = loss_cls.new_tensor(0)
+            losses['loss_bbox'] = loss_cls.new_tensor(0) * loss_cls.sum()
             if self.with_corner_loss:
-                losses['loss_corner'] = loss_cls.new_tensor(0)
+                losses['loss_corner'] = loss_cls.new_tensor(0) * loss_cls.sum()
         else:
             pos_bbox_pred = bbox_pred.view(rcnn_batch_size, -1)[pos_inds]
             bbox_weights_flat = bbox_weights[pos_inds].view(-1, 1).repeat(
@@ -367,7 +370,10 @@ class PartA2BboxHead(BaseModule):

         return losses

-    def get_targets(self, sampling_results, rcnn_train_cfg, concat=True):
+    def get_targets(self,
+                    sampling_results: SamplingResultList,
+                    rcnn_train_cfg: dict,
+                    concat: bool = True) -> Tuple[Tensor]:
         """Generate targets.

         Args:
@@ -407,7 +413,8 @@ class PartA2BboxHead(BaseModule):
         return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights,
                 bbox_weights)

-    def _get_target_single(self, pos_bboxes, pos_gt_bboxes, ious, cfg):
+    def _get_target_single(self, pos_bboxes: Tensor, pos_gt_bboxes: Tensor,
+                           ious: Tensor, cfg: dict) -> Tuple[Tensor]:
         """Generate training targets for a single sample.

         Args:
@@ -472,7 +479,10 @@ class PartA2BboxHead(BaseModule):
         return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights,
                 bbox_weights)

-    def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1.0):
+    def get_corner_loss_lidar(self,
+                              pred_bbox3d: Tensor,
+                              gt_bbox3d: Tensor,
+                              delta: float = 1.0) -> Tensor:
         """Calculate corner loss of given boxes.

         Args:
@@ -515,7 +525,7 @@ class PartA2BboxHead(BaseModule):
                    class_labels: Tensor,
                    class_pred: Tensor,
                    input_metas: List[dict],
-                   cfg: dict = None) -> List:
+                   cfg: dict = None) -> InstanceList:
         """Generate bboxes from bbox head predictions.

         Args:
@@ -528,7 +538,17 @@ class PartA2BboxHead(BaseModule):
             cfg (:obj:`ConfigDict`): Testing config.

         Returns:
-            list[tuple]: Decoded bbox, scores and labels after nms.
+            list[:obj:`InstanceData`]: Detection results of each sample
+            after the post process.
+            Each item usually contains following keys.
+
+            - scores_3d (Tensor): Classification scores, has a shape
+              (num_instances, )
+            - labels_3d (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
+              contains a tensor with shape (num_instances, C), where
+              C >= 7.
         """
         roi_batch_id = rois[..., 0]
         roi_boxes = rois[..., 1:]  # boxes without batch id
@@ -570,12 +590,12 @@ class PartA2BboxHead(BaseModule):
         return result_list

     def multi_class_nms(self,
-                        box_probs,
-                        box_preds,
-                        score_thr,
-                        nms_thr,
-                        input_meta,
-                        use_rotate_nms=True):
+                        box_probs: Tensor,
+                        box_preds: Tensor,
+                        score_thr: float,
+                        nms_thr: float,
+                        input_meta: dict,
+                        use_rotate_nms: bool = True) -> Tensor:
         """Multi-class NMS for box head.

         Note:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_conv_layer from mmcv.cnn.bricks import build_conv_layer
from mmengine.model import BaseModule, normal_init from mmengine.model import BaseModule, normal_init
from torch import nn as nn from mmengine.structures import InstanceData
from torch import Tensor
from mmdet3d.models.layers import nms_bev, nms_normal_bev from mmdet3d.models.layers import nms_bev, nms_normal_bev
from mmdet3d.models.layers.pointnet_modules import build_sa_module from mmdet3d.models.layers.pointnet_modules import build_sa_module
from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes, from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr) rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.utils.typing import InstanceList, SamplingResultList
from mmdet.models.utils import multi_apply from mmdet.models.utils import multi_apply
...@@ -24,17 +29,17 @@ class PointRCNNBboxHead(BaseModule): ...@@ -24,17 +29,17 @@ class PointRCNNBboxHead(BaseModule):
mlp_channels (list[int]): the number of mlp channels mlp_channels (list[int]): the number of mlp channels
pred_layer_cfg (dict, optional): Config of classfication and pred_layer_cfg (dict, optional): Config of classfication and
regression prediction layers. Defaults to None. regression prediction layers. Defaults to None.
num_points (tuple, optional): The number of points which each SA num_points (tuple): The number of points which each SA
module samples. Defaults to (128, 32, -1). module samples. Defaults to (128, 32, -1).
radius (tuple, optional): Sampling radius of each SA module. radius (tuple): Sampling radius of each SA module.
Defaults to (0.2, 0.4, 100). Defaults to (0.2, 0.4, 100).
num_samples (tuple, optional): The number of samples for ball query num_samples (tuple): The number of samples for ball query
in each SA module. Defaults to (64, 64, 64). in each SA module. Defaults to (64, 64, 64).
sa_channels (tuple, optional): Out channels of each mlp in SA module. sa_channels (tuple): Out channels of each mlp in SA module.
Defaults to ((128, 128, 128), (128, 128, 256), (256, 256, 512)). Defaults to ((128, 128, 128), (128, 128, 256), (256, 256, 512)).
bbox_coder (dict, optional): Config dict of box coders. bbox_coder (dict): Config dict of box coders.
Defaults to dict(type='DeltaXYZWLHRBBoxCoder'). Defaults to dict(type='DeltaXYZWLHRBBoxCoder').
sa_cfg (dict, optional): Config of set abstraction module, which may sa_cfg (dict): Config of set abstraction module, which may
contain the following keys and values: contain the following keys and values:
- pool_mod (str): Pool method ('max' or 'avg') for SA modules. - pool_mod (str): Pool method ('max' or 'avg') for SA modules.
...@@ -43,52 +48,53 @@ class PointRCNNBboxHead(BaseModule): ...@@ -43,52 +48,53 @@ class PointRCNNBboxHead(BaseModule):
each SA module. each SA module.
Defaults to dict(type='PointSAModule', pool_mod='max', Defaults to dict(type='PointSAModule', pool_mod='max',
use_xyz=True). use_xyz=True).
conv_cfg (dict, optional): Config dict of convolutional layers. conv_cfg (dict): Config dict of convolutional layers.
Defaults to dict(type='Conv1d'). Defaults to dict(type='Conv1d').
norm_cfg (dict, optional): Config dict of normalization layers. norm_cfg (dict): Config dict of normalization layers.
Defaults to dict(type='BN1d'). Defaults to dict(type='BN1d').
act_cfg (dict, optional): Config dict of activation layers. act_cfg (dict): Config dict of activation layers.
Defaults to dict(type='ReLU'). Defaults to dict(type='ReLU').
bias (str, optional): Type of bias. Defaults to 'auto'. bias (str): Type of bias. Defaults to 'auto'.
loss_bbox (dict, optional): Config of regression loss function. loss_bbox (dict): Config of regression loss function.
Defaults to dict(type='SmoothL1Loss', beta=1.0 / 9.0, Defaults to dict(type='SmoothL1Loss', beta=1.0 / 9.0,
reduction='sum', loss_weight=1.0). reduction='sum', loss_weight=1.0).
loss_cls (dict, optional): Config of classification loss function. loss_cls (dict): Config of classification loss function.
Defaults to dict(type='CrossEntropyLoss', use_sigmoid=True, Defaults to dict(type='CrossEntropyLoss', use_sigmoid=True,
reduction='sum', loss_weight=1.0). reduction='sum', loss_weight=1.0).
with_corner_loss (bool, optional): Whether using corner loss. with_corner_loss (bool): Whether using corner loss.
Defaults to True. Defaults to True.
init_cfg (dict, optional): Config of initialization. Defaults to None. init_cfg (dict, optional): Config of initialization. Defaults to None.
""" """
def __init__( def __init__(self,
self, num_classes: dict,
num_classes, in_channels: dict,
in_channels, mlp_channels: dict,
mlp_channels, pred_layer_cfg: Optional[dict] = None,
pred_layer_cfg=None, num_points: dict = (128, 32, -1),
num_points=(128, 32, -1), radius: dict = (0.2, 0.4, 100),
radius=(0.2, 0.4, 100), num_samples: dict = (64, 64, 64),
num_samples=(64, 64, 64), sa_channels: dict = ((128, 128, 128), (128, 128, 256),
sa_channels=((128, 128, 128), (128, 128, 256), (256, 256, 512)), (256, 256, 512)),
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), bbox_coder: dict = dict(type='DeltaXYZWLHRBBoxCoder'),
sa_cfg=dict(type='PointSAModule', pool_mod='max', use_xyz=True), sa_cfg: dict = dict(
conv_cfg=dict(type='Conv1d'), type='PointSAModule', pool_mod='max', use_xyz=True),
norm_cfg=dict(type='BN1d'), conv_cfg: dict = dict(type='Conv1d'),
act_cfg=dict(type='ReLU'), norm_cfg: dict = dict(type='BN1d'),
bias='auto', act_cfg: dict = dict(type='ReLU'),
loss_bbox=dict( bias: str = 'auto',
type='SmoothL1Loss', loss_bbox: dict = dict(
beta=1.0 / 9.0, type='SmoothL1Loss',
reduction='sum', beta=1.0 / 9.0,
loss_weight=1.0), reduction='sum',
loss_cls=dict( loss_weight=1.0),
type='CrossEntropyLoss', loss_cls: dict = dict(
use_sigmoid=True, type='CrossEntropyLoss',
reduction='sum', use_sigmoid=True,
loss_weight=1.0), reduction='sum',
with_corner_loss=True, loss_weight=1.0),
init_cfg=None): with_corner_loss: bool = True,
init_cfg: Optional[dict] = None) -> None:
super(PointRCNNBboxHead, self).__init__(init_cfg=init_cfg) super(PointRCNNBboxHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes self.num_classes = num_classes
self.num_sa = len(sa_channels) self.num_sa = len(sa_channels)
...@@ -169,7 +175,8 @@ class PointRCNNBboxHead(BaseModule): ...@@ -169,7 +175,8 @@ class PointRCNNBboxHead(BaseModule):
if init_cfg is None: if init_cfg is None:
self.init_cfg = dict(type='Xavier', layer=['Conv2d', 'Conv1d']) self.init_cfg = dict(type='Xavier', layer=['Conv2d', 'Conv1d'])
def _add_conv_branch(self, in_channels, conv_channels): def _add_conv_branch(self, in_channels: int,
conv_channels: tuple) -> nn.Sequential:
"""Add shared or separable branch. """Add shared or separable branch.
Args: Args:
...@@ -203,7 +210,7 @@ class PointRCNNBboxHead(BaseModule): ...@@ -203,7 +210,7 @@ class PointRCNNBboxHead(BaseModule):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
normal_init(self.conv_reg.weight, mean=0, std=0.001) normal_init(self.conv_reg.weight, mean=0, std=0.001)
def forward(self, feats): def forward(self, feats: Tensor) -> Tuple[Tensor]:
"""Forward pass. """Forward pass.
Args: Args:
...@@ -239,8 +246,10 @@ class PointRCNNBboxHead(BaseModule): ...@@ -239,8 +246,10 @@ class PointRCNNBboxHead(BaseModule):
rcnn_reg = rcnn_reg.transpose(1, 2).contiguous().squeeze(dim=1) rcnn_reg = rcnn_reg.transpose(1, 2).contiguous().squeeze(dim=1)
return rcnn_cls, rcnn_reg return rcnn_cls, rcnn_reg
def loss(self, cls_score, bbox_pred, rois, labels, bbox_targets, def loss(self, cls_score: Tensor, bbox_pred: Tensor, rois: Tensor,
pos_gt_bboxes, reg_mask, label_weights, bbox_weights): labels: Tensor, bbox_targets: Tensor, pos_gt_bboxes: Tensor,
reg_mask: Tensor, label_weights: Tensor,
bbox_weights: Tensor) -> Dict:
"""Computing losses. """Computing losses.
Args: Args:
...@@ -302,15 +311,17 @@ class PointRCNNBboxHead(BaseModule): ...@@ -302,15 +311,17 @@ class PointRCNNBboxHead(BaseModule):
# calculate corner loss # calculate corner loss
loss_corner = self.get_corner_loss_lidar(pred_boxes3d, loss_corner = self.get_corner_loss_lidar(pred_boxes3d,
pos_gt_bboxes) pos_gt_bboxes).mean()
losses['loss_corner'] = loss_corner losses['loss_corner'] = loss_corner
else: else:
losses['loss_corner'] = loss_cls.new_tensor(0) losses['loss_corner'] = loss_cls.new_tensor(0) * loss_cls.sum()
return losses return losses
def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1.0): def get_corner_loss_lidar(self,
pred_bbox3d: Tensor,
gt_bbox3d: Tensor,
delta: float = 1.0) -> Tensor:
"""Calculate corner loss of given boxes. """Calculate corner loss of given boxes.
Args: Args:
...@@ -340,19 +351,24 @@ class PointRCNNBboxHead(BaseModule): ...@@ -340,19 +351,24 @@ class PointRCNNBboxHead(BaseModule):
torch.norm(pred_box_corners - gt_box_corners_flip, dim=2)) torch.norm(pred_box_corners - gt_box_corners_flip, dim=2))
# huber loss # huber loss
abs_error = corner_dist.abs() abs_error = corner_dist.abs()
quadratic = abs_error.clamp(max=delta) # quadratic = abs_error.clamp(max=delta)
linear = (abs_error - quadratic) # linear = (abs_error - quadratic)
corner_loss = 0.5 * quadratic**2 + delta * linear # corner_loss = 0.5 * quadratic**2 + delta * linear
return corner_loss.mean(dim=1) loss = torch.where(abs_error < delta, 0.5 * abs_error**2 / delta,
abs_error - 0.5 * delta)
def get_targets(self, sampling_results, rcnn_train_cfg, concat=True): return loss.mean(dim=1)
def get_targets(self,
sampling_results: SamplingResultList,
rcnn_train_cfg: dict,
concat: bool = True) -> Tuple[Tensor]:
"""Generate targets. """Generate targets.
Args: Args:
sampling_results (list[:obj:`SamplingResult`]): sampling_results (list[:obj:`SamplingResult`]):
Sampled results from rois. Sampled results from rois.
rcnn_train_cfg (:obj:`ConfigDict`): Training config of rcnn. rcnn_train_cfg (:obj:`ConfigDict`): Training config of rcnn.
concat (bool, optional): Whether to concatenate targets between concat (bool): Whether to concatenate targets between
batches. Defaults to True. batches. Defaults to True.
Returns: Returns:
...@@ -385,7 +401,8 @@ class PointRCNNBboxHead(BaseModule): ...@@ -385,7 +401,8 @@ class PointRCNNBboxHead(BaseModule):
return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights,
bbox_weights) bbox_weights)
def _get_target_single(self, pos_bboxes, pos_gt_bboxes, ious, cfg): def _get_target_single(self, pos_bboxes: Tensor, pos_gt_bboxes: Tensor,
ious: Tensor, cfg: dict) -> Tuple[Tensor]:
"""Generate training targets for a single sample. """Generate training targets for a single sample.
Args: Args:
...@@ -449,13 +466,13 @@ class PointRCNNBboxHead(BaseModule): ...@@ -449,13 +466,13 @@ class PointRCNNBboxHead(BaseModule):
return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights,
bbox_weights) bbox_weights)
def get_bboxes(self, def get_results(self,
rois, rois: Tensor,
cls_score, cls_score: Tensor,
bbox_pred, bbox_pred: Tensor,
class_labels, class_labels: Tensor,
img_metas, input_metas: List[dict],
cfg=None): cfg: dict = None) -> InstanceList:
"""Generate bboxes from bbox head predictions. """Generate bboxes from bbox head predictions.
Args: Args:
...@@ -463,12 +480,22 @@ class PointRCNNBboxHead(BaseModule): ...@@ -463,12 +480,22 @@ class PointRCNNBboxHead(BaseModule):
cls_score (torch.Tensor): Scores of bounding boxes. cls_score (torch.Tensor): Scores of bounding boxes.
bbox_pred (torch.Tensor): Bounding box predictions. bbox_pred (torch.Tensor): Bounding box predictions.
class_labels (torch.Tensor): Labels of classes. class_labels (torch.Tensor): Labels of classes.
img_metas (list[dict]): Point cloud and image's meta info. input_metas (list[dict]): Point cloud and image's meta info.
cfg (:obj:`ConfigDict`, optional): Testing config. cfg (:obj:`ConfigDict`, optional): Testing config.
Defaults to None. Defaults to None.
Returns: Returns:
list[tuple]: Decoded bbox, scores and labels after nms. list[:obj:`InstanceData`]: Detection results of each sample
after the post process.
Each item usually contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
contains a tensor with shape (num_instances, C), where
C >= 7.
""" """
roi_batch_id = rois[..., 0] roi_batch_id = rois[..., 0]
roi_boxes = rois[..., 1:] # boxes without batch id roi_boxes = rois[..., 1:] # boxes without batch id
...@@ -494,25 +521,27 @@ class PointRCNNBboxHead(BaseModule): ...@@ -494,25 +521,27 @@ class PointRCNNBboxHead(BaseModule):
cur_rcnn_boxes3d = rcnn_boxes3d[roi_batch_id == batch_id] cur_rcnn_boxes3d = rcnn_boxes3d[roi_batch_id == batch_id]
keep = self.multi_class_nms(cur_box_prob, cur_rcnn_boxes3d, keep = self.multi_class_nms(cur_box_prob, cur_rcnn_boxes3d,
cfg.score_thr, cfg.nms_thr, cfg.score_thr, cfg.nms_thr,
img_metas[batch_id], input_metas[batch_id],
cfg.use_rotate_nms) cfg.use_rotate_nms)
selected_bboxes = cur_rcnn_boxes3d[keep] selected_bboxes = cur_rcnn_boxes3d[keep]
selected_label_preds = cur_class_labels[keep] selected_label_preds = cur_class_labels[keep]
selected_scores = cur_cls_score[keep] selected_scores = cur_cls_score[keep]
results = InstanceData()
results.bboxes_3d = input_metas[batch_id]['box_type_3d'](
selected_bboxes, selected_bboxes.shape[-1])
results.scores_3d = selected_scores
results.labels_3d = selected_label_preds
result_list.append( result_list.append(results)
(img_metas[batch_id]['box_type_3d'](selected_bboxes,
self.bbox_coder.code_size),
selected_scores, selected_label_preds))
return result_list return result_list
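get_results now returns a list of InstanceData objects rather than raw tuples, so callers read detections by field name. A hedged sketch of the consumer side (plain tensors stand in for BaseInstance3DBoxes):

import torch
from mmengine.structures import InstanceData

results = InstanceData()
results.bboxes_3d = torch.rand(4, 7)                 # stand-in for 3D boxes
results.scores_3d = torch.rand(4)
results.labels_3d = torch.zeros(4, dtype=torch.long)

keep = results.scores_3d > 0.3
filtered = results[keep]   # InstanceData supports boolean-mask indexing
print(filtered.scores_3d.shape, filtered.labels_3d.shape)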
def multi_class_nms(self, def multi_class_nms(self,
box_probs, box_probs: Tensor,
box_preds, box_preds: Tensor,
score_thr, score_thr: float,
nms_thr, nms_thr: float,
input_meta, input_meta: dict,
use_rotate_nms=True): use_rotate_nms: bool = True) -> Tensor:
"""Multi-class NMS for box head. """Multi-class NMS for box head.
Note: Note:
...@@ -527,7 +556,7 @@ class PointRCNNBboxHead(BaseModule): ...@@ -527,7 +556,7 @@ class PointRCNNBboxHead(BaseModule):
score_thr (float): Threshold of scores. score_thr (float): Threshold of scores.
nms_thr (float): Threshold for NMS. nms_thr (float): Threshold for NMS.
input_meta (dict): Meta information of the current sample. input_meta (dict): Meta information of the current sample.
use_rotate_nms (bool, optional): Whether to use rotated nms. use_rotate_nms (bool): Whether to use rotated nms.
Defaults to True. Defaults to True.
Returns: Returns:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple
import torch import torch
from mmengine.model import BaseModule from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.models.builder import build_loss
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures.bbox_3d import rotation_3d_in_axis from mmdet3d.structures.bbox_3d import BaseInstance3DBoxes, rotation_3d_in_axis
from mmdet3d.utils import InstanceList from mmdet3d.utils import InstanceList
from mmdet.models.utils import multi_apply from mmdet.models.utils import multi_apply
...@@ -26,23 +28,23 @@ class PointwiseSemanticHead(BaseModule): ...@@ -26,23 +28,23 @@ class PointwiseSemanticHead(BaseModule):
loss_part (dict): Config of part prediction loss. loss_part (dict): Config of part prediction loss.
""" """
def __init__(self, def __init__(
in_channels, self,
num_classes=3, in_channels: int,
extra_width=0.2, num_classes: int = 3,
seg_score_thr=0.3, extra_width: float = 0.2,
init_cfg=None, seg_score_thr: float = 0.3,
loss_seg=dict( init_cfg: Optional[dict] = None,
type='FocalLoss', loss_seg: dict = dict(
use_sigmoid=True, type='FocalLoss',
reduction='sum', use_sigmoid=True,
gamma=2.0, reduction='sum',
alpha=0.25, gamma=2.0,
loss_weight=1.0), alpha=0.25,
loss_part=dict( loss_weight=1.0),
type='CrossEntropyLoss', loss_part: dict = dict(
use_sigmoid=True, type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
loss_weight=1.0)): ) -> None:
super(PointwiseSemanticHead, self).__init__(init_cfg=init_cfg) super(PointwiseSemanticHead, self).__init__(init_cfg=init_cfg)
self.extra_width = extra_width self.extra_width = extra_width
self.num_classes = num_classes self.num_classes = num_classes
...@@ -50,10 +52,10 @@ class PointwiseSemanticHead(BaseModule): ...@@ -50,10 +52,10 @@ class PointwiseSemanticHead(BaseModule):
self.seg_cls_layer = nn.Linear(in_channels, 1, bias=True) self.seg_cls_layer = nn.Linear(in_channels, 1, bias=True)
self.seg_reg_layer = nn.Linear(in_channels, 3, bias=True) self.seg_reg_layer = nn.Linear(in_channels, 3, bias=True)
self.loss_seg = build_loss(loss_seg) self.loss_seg = MODELS.build(loss_seg)
self.loss_part = build_loss(loss_part) self.loss_part = MODELS.build(loss_part)
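Loss builders switch from the legacy build_loss helper to the MODELS registry; build() looks up the class registered under the type key and instantiates it with the remaining keys (scoped names such as mmdet.FocalLoss resolve through the parent registry). A sketch of the pattern, mirroring the config in this PR:

from mmdet3d.registry import MODELS

loss_seg = MODELS.build(
    dict(
        type='mmdet.FocalLoss',  # scoped type resolves via the mmdet registry
        use_sigmoid=True,
        reduction='sum',
        gamma=2.0,
        alpha=0.25,
        loss_weight=1.0))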
def forward(self, x): def forward(self, x: Tensor) -> Dict[str, Tensor]:
"""Forward pass. """Forward pass.
Args: Args:
...@@ -79,7 +81,9 @@ class PointwiseSemanticHead(BaseModule): ...@@ -79,7 +81,9 @@ class PointwiseSemanticHead(BaseModule):
return dict( return dict(
seg_preds=seg_preds, part_preds=part_preds, part_feats=part_feats) seg_preds=seg_preds, part_preds=part_preds, part_feats=part_feats)
def get_targets_single(self, voxel_centers, gt_bboxes_3d, gt_labels_3d): def get_targets_single(self, voxel_centers: Tensor,
gt_bboxes_3d: BaseInstance3DBoxes,
gt_labels_3d: Tensor) -> Tuple[Tensor]:
"""generate segmentation and part prediction targets for a single """generate segmentation and part prediction targets for a single
sample. sample.
...@@ -162,7 +166,8 @@ class PointwiseSemanticHead(BaseModule): ...@@ -162,7 +166,8 @@ class PointwiseSemanticHead(BaseModule):
part_targets = torch.cat(part_targets, dim=0) part_targets = torch.cat(part_targets, dim=0)
return dict(seg_targets=seg_targets, part_targets=part_targets) return dict(seg_targets=seg_targets, part_targets=part_targets)
def loss(self, semantic_results, semantic_targets): def loss(self, semantic_results: dict,
semantic_targets: dict) -> Dict[str, Tensor]:
"""Calculate point-wise segmentation and part prediction losses. """Calculate point-wise segmentation and part prediction losses.
Args: Args:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
...@@ -12,6 +12,7 @@ from torch.nn import functional as F ...@@ -12,6 +12,7 @@ from torch.nn import functional as F
from mmdet3d.models.layers import VoteModule, build_sa_module from mmdet3d.models.layers import VoteModule, build_sa_module
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample from mmdet3d.structures import Det3DDataSample
from mmdet3d.structures.bbox_3d import BaseInstance3DBoxes
from mmdet.models.utils import multi_apply from mmdet.models.utils import multi_apply
...@@ -26,39 +27,42 @@ class PrimitiveHead(BaseModule): ...@@ -26,39 +27,42 @@ class PrimitiveHead(BaseModule):
available mode ['z', 'xy', 'line']. available mode ['z', 'xy', 'line'].
bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
decoding boxes. decoding boxes.
train_cfg (dict): Config for training. train_cfg (dict, optional): Config for training.
test_cfg (dict): Config for testing. test_cfg (dict, optional): Config for testing.
vote_module_cfg (dict): Config of VoteModule for point-wise votes. vote_module_cfg (dict, optional): Config of VoteModule for point-wise
vote_aggregation_cfg (dict): Config of vote aggregation layer. votes.
vote_aggregation_cfg (dict, optional): Config of vote aggregation
layer.
feat_channels (tuple[int]): Convolution channels of feat_channels (tuple[int]): Convolution channels of
prediction layer. prediction layer.
upper_thresh (float): Threshold for line matching. upper_thresh (float): Threshold for line matching.
surface_thresh (float): Threshold for surface matching. surface_thresh (float): Threshold for surface matching.
conv_cfg (dict): Config of convolution in prediction layer. conv_cfg (dict, optional): Config of convolution in prediction layer.
norm_cfg (dict): Config of BN in prediction layer. norm_cfg (dict, optional): Config of BN in prediction layer.
objectness_loss (dict): Config of objectness loss. objectness_loss (dict, optional): Config of objectness loss.
center_loss (dict): Config of center loss. center_loss (dict, optional): Config of center loss.
semantic_loss (dict): Config of point-wise semantic segmentation loss. semantic_loss (dict, optional): Config of point-wise semantic
segmentation loss.
""" """
def __init__(self, def __init__(self,
num_dims: int, num_dims: int,
num_classes: int, num_classes: int,
primitive_mode: str, primitive_mode: str,
train_cfg: dict = None, train_cfg: Optional[dict] = None,
test_cfg: dict = None, test_cfg: Optional[dict] = None,
vote_module_cfg: dict = None, vote_module_cfg: Optional[dict] = None,
vote_aggregation_cfg: dict = None, vote_aggregation_cfg: Optional[dict] = None,
feat_channels: tuple = (128, 128), feat_channels: tuple = (128, 128),
upper_thresh: float = 100.0, upper_thresh: float = 100.0,
surface_thresh: float = 0.5, surface_thresh: float = 0.5,
conv_cfg: dict = dict(type='Conv1d'), conv_cfg: dict = dict(type='Conv1d'),
norm_cfg: dict = dict(type='BN1d'), norm_cfg: dict = dict(type='BN1d'),
objectness_loss: dict = None, objectness_loss: Optional[dict] = None,
center_loss: dict = None, center_loss: Optional[dict] = None,
semantic_reg_loss: dict = None, semantic_reg_loss: Optional[dict] = None,
semantic_cls_loss: dict = None, semantic_cls_loss: Optional[dict] = None,
init_cfg: dict = None): init_cfg: Optional[dict] = None):
super(PrimitiveHead, self).__init__(init_cfg=init_cfg) super(PrimitiveHead, self).__init__(init_cfg=init_cfg)
# bounding boxes centers, face centers and edge centers # bounding boxes centers, face centers and edge centers
assert primitive_mode in ['z', 'xy', 'line'] assert primitive_mode in ['z', 'xy', 'line']
...@@ -126,7 +130,7 @@ class PrimitiveHead(BaseModule): ...@@ -126,7 +130,7 @@ class PrimitiveHead(BaseModule):
assert sample_mode in ['vote', 'seed', 'random'] assert sample_mode in ['vote', 'seed', 'random']
return sample_mode return sample_mode
def forward(self, feats_dict): def forward(self, feats_dict: dict) -> dict:
"""Forward pass. """Forward pass.
Args: Args:
...@@ -255,10 +259,8 @@ class PrimitiveHead(BaseModule): ...@@ -255,10 +259,8 @@ class PrimitiveHead(BaseModule):
attributes. attributes.
batch_pts_semantic_mask (list[tensor]): Semantic mask batch_pts_semantic_mask (list[tensor]): Semantic mask
of point cloud. Defaults to None. of point cloud. Defaults to None.
batch_pts_semantic_mask (list[tensor]): Instance mask batch_pts_instance_mask (list[tensor]): Instance mask
of point cloud. Defaults to None. of point cloud. Defaults to None.
batch_input_metas (list[dict]): Contain pcd and img's meta info.
ret_target (bool): Return targets or not. Defaults to False.
Returns: Returns:
dict: Losses of Primitive Head. dict: Losses of Primitive Head.
...@@ -392,12 +394,13 @@ class PrimitiveHead(BaseModule): ...@@ -392,12 +394,13 @@ class PrimitiveHead(BaseModule):
return (point_mask, point_offset, gt_primitive_center, return (point_mask, point_offset, gt_primitive_center,
gt_primitive_semantic, gt_sem_cls_label, gt_votes_mask) gt_primitive_semantic, gt_sem_cls_label, gt_votes_mask)
def get_targets_single(self, def get_targets_single(
points, self,
gt_bboxes_3d, points: torch.Tensor,
gt_labels_3d, gt_bboxes_3d: BaseInstance3DBoxes,
pts_semantic_mask=None, gt_labels_3d: torch.Tensor,
pts_instance_mask=None): pts_semantic_mask: torch.Tensor = None,
pts_instance_mask: torch.Tensor = None) -> Tuple[torch.Tensor]:
"""Generate targets of primitive head for single batch. """Generate targets of primitive head for single batch.
Args: Args:
...@@ -668,7 +671,8 @@ class PrimitiveHead(BaseModule): ...@@ -668,7 +671,8 @@ class PrimitiveHead(BaseModule):
return (point_mask, point_sem, point_offset) return (point_mask, point_sem, point_offset)
def primitive_decode_scores(self, predictions, aggregated_points): def primitive_decode_scores(self, predictions: torch.Tensor,
aggregated_points: torch.Tensor) -> dict:
"""Decode predicted parts to primitive head. """Decode predicted parts to primitive head.
Args: Args:
...@@ -696,7 +700,7 @@ class PrimitiveHead(BaseModule): ...@@ -696,7 +700,7 @@ class PrimitiveHead(BaseModule):
return ret_dict return ret_dict
def check_horizon(self, points): def check_horizon(self, points: torch.Tensor) -> bool:
"""Check whether is a horizontal plane. """Check whether is a horizontal plane.
Args: Args:
...@@ -709,7 +713,8 @@ class PrimitiveHead(BaseModule): ...@@ -709,7 +713,8 @@ class PrimitiveHead(BaseModule):
(points[1][-1] == points[2][-1]) and \ (points[1][-1] == points[2][-1]) and \
(points[2][-1] == points[3][-1]) (points[2][-1] == points[3][-1])
def check_dist(self, plane_equ, points): def check_dist(self, plane_equ: torch.Tensor,
points: torch.Tensor) -> tuple:
"""Whether the mean of points to plane distance is lower than thresh. """Whether the mean of points to plane distance is lower than thresh.
Args: Args:
...@@ -722,7 +727,8 @@ class PrimitiveHead(BaseModule): ...@@ -722,7 +727,8 @@ class PrimitiveHead(BaseModule):
return (points[:, 2] + return (points[:, 2] +
plane_equ[-1]).sum() / 4.0 < self.train_cfg['lower_thresh'] plane_equ[-1]).sum() / 4.0 < self.train_cfg['lower_thresh']
def point2line_dist(self, points, pts_a, pts_b): def point2line_dist(self, points: torch.Tensor, pts_a: torch.Tensor,
pts_b: torch.Tensor) -> torch.Tensor:
"""Calculate the distance from point to line. """Calculate the distance from point to line.
Args: Args:
...@@ -741,7 +747,11 @@ class PrimitiveHead(BaseModule): ...@@ -741,7 +747,11 @@ class PrimitiveHead(BaseModule):
return dist return dist
def match_point2line(self, points, corners, with_yaw, mode='bottom'): def match_point2line(self,
points: torch.Tensor,
corners: torch.Tensor,
with_yaw: bool,
mode: str = 'bottom') -> tuple:
"""Match points to corresponding line. """Match points to corresponding line.
Args: Args:
...@@ -782,7 +792,8 @@ class PrimitiveHead(BaseModule): ...@@ -782,7 +792,8 @@ class PrimitiveHead(BaseModule):
selected_list = [sel1, sel2, sel3, sel4] selected_list = [sel1, sel2, sel3, sel4]
return selected_list return selected_list
def match_point2plane(self, plane, points): def match_point2plane(self, plane: torch.Tensor,
points: torch.Tensor) -> tuple:
"""Match points to plane. """Match points to plane.
Args: Args:
...@@ -800,10 +811,14 @@ class PrimitiveHead(BaseModule): ...@@ -800,10 +811,14 @@ class PrimitiveHead(BaseModule):
min_dist) < self.train_cfg['dist_thresh'] min_dist) < self.train_cfg['dist_thresh']
return point2plane_dist, selected return point2plane_dist, selected
def compute_primitive_loss(self, primitive_center, primitive_semantic, def compute_primitive_loss(self, primitive_center: torch.Tensor,
semantic_scores, num_proposal, primitive_semantic: torch.Tensor,
gt_primitive_center, gt_primitive_semantic, semantic_scores: torch.Tensor,
gt_sem_cls_label, gt_primitive_mask): num_proposal: torch.Tensor,
gt_primitive_center: torch.Tensor,
gt_primitive_semantic: torch.Tensor,
gt_sem_cls_label: torch.Tensor,
gt_primitive_mask: torch.Tensor) -> Tuple:
"""Compute loss of primitive module. """Compute loss of primitive module.
Args: Args:
...@@ -849,7 +864,8 @@ class PrimitiveHead(BaseModule): ...@@ -849,7 +864,8 @@ class PrimitiveHead(BaseModule):
return center_loss, size_loss, sem_cls_loss return center_loss, size_loss, sem_cls_loss
def get_primitive_center(self, pred_flag, center): def get_primitive_center(self, pred_flag: torch.Tensor,
center: torch.Tensor) -> Tuple:
"""Generate primitive center from predictions. """Generate primitive center from predictions.
Args: Args:
...@@ -869,17 +885,17 @@ class PrimitiveHead(BaseModule): ...@@ -869,17 +885,17 @@ class PrimitiveHead(BaseModule):
return center, pred_indices return center, pred_indices
def _assign_primitive_line_targets(self, def _assign_primitive_line_targets(self,
point_mask, point_mask: torch.Tensor,
point_offset, point_offset: torch.Tensor,
point_sem, point_sem: torch.Tensor,
coords, coords: torch.Tensor,
indices, indices: torch.Tensor,
cls_label, cls_label: int,
point2line_matching, point2line_matching: torch.Tensor,
corners, corners: torch.Tensor,
center_axises, center_axises: torch.Tensor,
with_yaw, with_yaw: bool,
mode='bottom'): mode: str = 'bottom') -> Tuple:
"""Generate targets of line primitive. """Generate targets of line primitive.
Args: Args:
...@@ -934,15 +950,15 @@ class PrimitiveHead(BaseModule): ...@@ -934,15 +950,15 @@ class PrimitiveHead(BaseModule):
return point_mask, point_offset, point_sem return point_mask, point_offset, point_sem
def _assign_primitive_surface_targets(self, def _assign_primitive_surface_targets(self,
point_mask, point_mask: torch.Tensor,
point_offset, point_offset: torch.Tensor,
point_sem, point_sem: torch.Tensor,
coords, coords: torch.Tensor,
indices, indices: torch.Tensor,
cls_label, cls_label: int,
corners, corners: torch.Tensor,
with_yaw, with_yaw: bool,
mode='bottom'): mode: str = 'bottom') -> Tuple:
"""Generate targets for primitive z and primitive xy. """Generate targets for primitive z and primitive xy.
Args: Args:
...@@ -1017,7 +1033,9 @@ class PrimitiveHead(BaseModule): ...@@ -1017,7 +1033,9 @@ class PrimitiveHead(BaseModule):
point_offset[indices] = center - coords point_offset[indices] = center - coords
return point_mask, point_offset, point_sem return point_mask, point_offset, point_sem
def _get_plane_fomulation(self, vector1, vector2, point): def _get_plane_fomulation(self, vector1: torch.Tensor,
vector2: torch.Tensor,
point: torch.Tensor) -> torch.Tensor:
"""Compute the equation of the plane. """Compute the equation of the plane.
Args: Args:
......
...@@ -90,16 +90,18 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -90,16 +90,18 @@ class PartAggregationROIHead(Base3DRoIHead):
return bbox_results return bbox_results
def _assign_and_sample( def _assign_and_sample(
self, proposal_list: InstanceList, self, rpn_results_list: InstanceList,
batch_gt_instances_3d: InstanceList) -> List[SamplingResult]: batch_gt_instances_3d: InstanceList,
batch_gt_instances_ignore: InstanceList) -> List[SamplingResult]:
"""Assign and sample proposals for training. """Assign and sample proposals for training.
Args: Args:
proposal_list (list[:obj:`InstancesData`]): Proposals produced by rpn_results_list (List[:obj:`InstanceData`]): Detection results
rpn head. of rpn head.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes. ``labels_3d`` attributes.
batch_gt_instances_ignore (list): Ignore instances of gt bboxes.
Returns: Returns:
list[:obj:`SamplingResult`]: Sampled results of each training list[:obj:`SamplingResult`]: Sampled results of each training
...@@ -107,16 +109,16 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -107,16 +109,16 @@ class PartAggregationROIHead(Base3DRoIHead):
""" """
sampling_results = [] sampling_results = []
# bbox assign # bbox assign
for batch_idx in range(len(proposal_list)): for batch_idx in range(len(rpn_results_list)):
cur_proposal_list = proposal_list[batch_idx] cur_proposal_list = rpn_results_list[batch_idx]
cur_boxes = cur_proposal_list['bboxes_3d'] cur_boxes = cur_proposal_list['bboxes_3d']
cur_labels_3d = cur_proposal_list['labels_3d'] cur_labels_3d = cur_proposal_list['labels_3d']
cur_gt_instances_3d = batch_gt_instances_3d[batch_idx] cur_gt_instances_3d = batch_gt_instances_3d[batch_idx]
cur_gt_instances_ignore = batch_gt_instances_ignore[batch_idx]
cur_gt_instances_3d.bboxes_3d = cur_gt_instances_3d.\ cur_gt_instances_3d.bboxes_3d = cur_gt_instances_3d.\
bboxes_3d.tensor bboxes_3d.tensor
cur_gt_bboxes = batch_gt_instances_3d[batch_idx].bboxes_3d.to( cur_gt_bboxes = cur_gt_instances_3d.bboxes_3d.to(cur_boxes.device)
cur_boxes.device) cur_gt_labels = cur_gt_instances_3d.labels_3d
cur_gt_labels = batch_gt_instances_3d[batch_idx].labels_3d
batch_num_gts = 0 batch_num_gts = 0
# 0 is bg # 0 is bg
...@@ -132,7 +134,8 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -132,7 +134,8 @@ class PartAggregationROIHead(Base3DRoIHead):
pred_per_cls = (cur_labels_3d == i) pred_per_cls = (cur_labels_3d == i)
cur_assign_res = assigner.assign( cur_assign_res = assigner.assign(
cur_proposal_list[pred_per_cls], cur_proposal_list[pred_per_cls],
cur_gt_instances_3d[gt_per_cls]) cur_gt_instances_3d[gt_per_cls],
cur_gt_instances_ignore)
# gather assign_results in different class into one result # gather assign_results in different class into one result
batch_num_gts += cur_assign_res.num_gts batch_num_gts += cur_assign_res.num_gts
# gt inds (1-based) # gt inds (1-based)
...@@ -158,7 +161,8 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -158,7 +161,8 @@ class PartAggregationROIHead(Base3DRoIHead):
batch_gt_labels) batch_gt_labels)
else: # for single class else: # for single class
assign_result = self.bbox_assigner.assign( assign_result = self.bbox_assigner.assign(
cur_proposal_list, cur_gt_instances_3d) cur_proposal_list, cur_gt_instances_3d,
cur_gt_instances_ignore)
# sample boxes # sample boxes
sampling_result = self.bbox_sampler.sample(assign_result, sampling_result = self.bbox_sampler.sample(assign_result,
cur_boxes.tensor, cur_boxes.tensor,
...@@ -200,7 +204,7 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -200,7 +204,7 @@ class PartAggregationROIHead(Base3DRoIHead):
Args: Args:
feats_dict (dict): Contains features from the first stage. feats_dict (dict): Contains features from the first stage.
rpn_results_list (List[:obj:`InstancesData`]): Detection results rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head. of rpn head.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as samples. It usually includes information such as
...@@ -247,7 +251,7 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -247,7 +251,7 @@ class PartAggregationROIHead(Base3DRoIHead):
voxel_dict (dict): Contains information of voxels. voxel_dict (dict): Contains information of voxels.
batch_input_metas (list[dict], Optional): Batch image meta info. batch_input_metas (list[dict], Optional): Batch image meta info.
Defaults to None. Defaults to None.
rpn_results_list (List[:obj:`InstancesData`]): Detection results rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head. of rpn head.
test_cfg (Config): Test config. test_cfg (Config): Test config.
...@@ -316,7 +320,7 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -316,7 +320,7 @@ class PartAggregationROIHead(Base3DRoIHead):
Args: Args:
feats_dict (dict): Contains features from the first stage. feats_dict (dict): Contains features from the first stage.
rpn_results_list (List[:obj:`InstancesData`]): Detection results rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head. of rpn head.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as samples. It usually includes information such as
...@@ -342,7 +346,8 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -342,7 +346,8 @@ class PartAggregationROIHead(Base3DRoIHead):
losses.update(semantic_results.pop('loss_semantic')) losses.update(semantic_results.pop('loss_semantic'))
sample_results = self._assign_and_sample(rpn_results_list, sample_results = self._assign_and_sample(rpn_results_list,
batch_gt_instances_3d) batch_gt_instances_3d,
batch_gt_instances_ignore)
if self.with_bbox: if self.with_bbox:
feats_dict.update(semantic_results) feats_dict.update(semantic_results)
bbox_results = self._bbox_forward_train(feats_dict, voxels_dict, bbox_results = self._bbox_forward_train(feats_dict, voxels_dict,
...@@ -358,7 +363,7 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -358,7 +363,7 @@ class PartAggregationROIHead(Base3DRoIHead):
Args: Args:
feats_dict (dict): Contains features from the first stage. feats_dict (dict): Contains features from the first stage.
rpn_results_list (List[:obj:`InstancesData`]): Detection results rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head. of rpn head.
Returns: Returns:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
import torch import torch
from torch import Tensor
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures import bbox3d2result, bbox3d2roi from mmdet3d.structures import bbox3d2roi
from mmdet3d.utils.typing import InstanceList, SampleList
from mmdet.models.task_modules import AssignResult from mmdet.models.task_modules import AssignResult
from .base_3droi_head import Base3DRoIHead from .base_3droi_head import Base3DRoIHead
...@@ -14,43 +18,31 @@ class PointRCNNRoIHead(Base3DRoIHead): ...@@ -14,43 +18,31 @@ class PointRCNNRoIHead(Base3DRoIHead):
Args: Args:
bbox_head (dict): Config of bbox_head. bbox_head (dict): Config of bbox_head.
point_roi_extractor (dict): Config of RoI extractor. bbox_roi_extractor (dict): Config of RoI extractor.
train_cfg (dict): Train configs. train_cfg (dict): Train configs.
test_cfg (dict): Test configs. test_cfg (dict): Test configs.
depth_normalizer (float, optional): Normalize depth feature. depth_normalizer (float): Normalize depth feature.
Defaults to 70.0. Defaults to 70.0.
init_cfg (dict, optional): Config of initialization. Defaults to None. init_cfg (dict, optional): Config of initialization. Defaults to None.
""" """
def __init__(self, def __init__(self,
bbox_head, bbox_head: dict,
point_roi_extractor, bbox_roi_extractor: dict,
train_cfg, train_cfg: dict,
test_cfg, test_cfg: dict,
depth_normalizer=70.0, depth_normalizer: float = 70.0,
pretrained=None, init_cfg: Optional[dict] = None) -> None:
init_cfg=None):
super(PointRCNNRoIHead, self).__init__( super(PointRCNNRoIHead, self).__init__(
bbox_head=bbox_head, bbox_head=bbox_head,
bbox_roi_extractor=bbox_roi_extractor,
train_cfg=train_cfg, train_cfg=train_cfg,
test_cfg=test_cfg, test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg) init_cfg=init_cfg)
self.depth_normalizer = depth_normalizer self.depth_normalizer = depth_normalizer
if point_roi_extractor is not None:
self.point_roi_extractor = MODELS.build(point_roi_extractor)
self.init_assigner_sampler() self.init_assigner_sampler()
def init_bbox_head(self, bbox_head):
"""Initialize box head.
Args:
bbox_head (dict): Config dict of RoI Head.
"""
self.bbox_head = MODELS.build(bbox_head)
def init_mask_head(self): def init_mask_head(self):
"""Initialize maek head.""" """Initialize maek head."""
pass pass
...@@ -68,77 +60,101 @@ class PointRCNNRoIHead(Base3DRoIHead): ...@@ -68,77 +60,101 @@ class PointRCNNRoIHead(Base3DRoIHead):
] ]
self.bbox_sampler = TASK_UTILS.build(self.train_cfg.sampler) self.bbox_sampler = TASK_UTILS.build(self.train_cfg.sampler)
def forward_train(self, feats_dict, input_metas, proposal_list, def loss(self, feats_dict: Dict, rpn_results_list: InstanceList,
gt_bboxes_3d, gt_labels_3d): batch_data_samples: SampleList, **kwargs) -> dict:
"""Training forward function of PointRCNNRoIHead. """Perform forward propagation and loss calculation of the detection
roi on the features of the upstream network.
Args: Args:
feats_dict (dict): Contains features from the first stage. feats_dict (dict): Contains features from the first stage.
imput_metas (list[dict]): Meta info of each input. rpn_results_list (List[:obj:`InstanceData`]): Detection results
proposal_list (list[dict]): Proposal information from rpn. of rpn head.
The dictionary should contain the following keys: batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
- boxes_3d (:obj:`BaseInstance3DBoxes`): Proposal bboxes `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
- labels_3d (torch.Tensor): Labels of proposals
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]):
GT bboxes of each sample. The bboxes are encapsulated
by 3D box bboxes_3d.
gt_labels_3d (list[LongTensor]): GT labels of each sample.
Returns: Returns:
dict: Losses from RoI RCNN head. dict[str, Tensor]: A dictionary of loss components
- loss_bbox (torch.Tensor): Loss of bboxes
""" """
features = feats_dict['features'] features = feats_dict['fp_features']
points = feats_dict['points'] fp_points = feats_dict['fp_points']
point_cls_preds = feats_dict['points_cls_preds'] point_cls_preds = feats_dict['points_cls_preds']
sem_scores = point_cls_preds.sigmoid() sem_scores = point_cls_preds.sigmoid()
point_scores = sem_scores.max(-1)[0] point_scores = sem_scores.max(-1)[0]
batch_gt_instances_3d = []
sample_results = self._assign_and_sample(proposal_list, gt_bboxes_3d, batch_gt_instances_ignore = []
gt_labels_3d) for data_sample in batch_data_samples:
batch_gt_instances_3d.append(data_sample.gt_instances_3d)
if 'ignored_instances' in data_sample:
batch_gt_instances_ignore.append(data_sample.ignored_instances)
else:
batch_gt_instances_ignore.append(None)
sample_results = self._assign_and_sample(rpn_results_list,
batch_gt_instances_3d,
batch_gt_instances_ignore)
# concat the depth, semantic features and backbone features # concat the depth, semantic features and backbone features
features = features.transpose(1, 2).contiguous() features = features.transpose(1, 2).contiguous()
point_depths = points.norm(dim=2) / self.depth_normalizer - 0.5 point_depths = fp_points.norm(dim=2) / self.depth_normalizer - 0.5
features_list = [ features_list = [
point_scores.unsqueeze(2), point_scores.unsqueeze(2),
point_depths.unsqueeze(2), features point_depths.unsqueeze(2), features
] ]
features = torch.cat(features_list, dim=2) features = torch.cat(features_list, dim=2)
bbox_results = self._bbox_forward_train(features, points, bbox_results = self._bbox_forward_train(features, fp_points,
sample_results) sample_results)
losses = dict() losses = dict()
losses.update(bbox_results['loss_bbox']) losses.update(bbox_results['loss_bbox'])
return losses return losses
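Depth is used as an extra per-point feature: each point's range is divided by depth_normalizer (70.0 by default) and shifted by 0.5 before being concatenated with the semantic score and backbone channels. A shape-level sketch with assumed sizes:

import torch

B, N, C = 2, 1024, 128                    # assumed sizes, illustration only
fp_points = torch.rand(B, N, 3) * 70.0    # xyz coordinates
features = torch.rand(B, N, C)            # backbone features, (B, N, C)
point_scores = torch.rand(B, N)           # max semantic score per point

point_depths = fp_points.norm(dim=2) / 70.0 - 0.5   # recentre depth near 0
fused = torch.cat(
    [point_scores.unsqueeze(2), point_depths.unsqueeze(2), features], dim=2)
print(fused.shape)  # torch.Size([2, 1024, 130]): C + 2 channels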
def simple_test(self, feats_dict, img_metas, proposal_list, **kwargs): def predict(self,
"""Simple testing forward function of PointRCNNRoIHead. feats_dict: Dict,
rpn_results_list: InstanceList,
Note: batch_data_samples: SampleList,
This function assumes that the batch size is 1 rescale: bool = False,
**kwargs) -> InstanceList:
"""Perform forward propagation of the roi head and predict detection
results on the features of the upstream network.
Args: Args:
feats_dict (dict): Contains features from the first stage. feats_dict (dict): Contains features from the first stage.
img_metas (list[dict]): Meta info of each image. rpn_results_list (List[:obj:`InstanceData`]): Detection results
proposal_list (list[dict]): Proposal information from rpn. of rpn head.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
Returns: Returns:
dict: Bbox results of one frame. list[:obj:`InstanceData`]: Detection results of each sample
after the post process.
Each item usually contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
contains a tensor with shape (num_instances, C), where
C >= 7.
""" """
rois = bbox3d2roi([res['boxes_3d'].tensor for res in proposal_list]) rois = bbox3d2roi(
labels_3d = [res['labels_3d'] for res in proposal_list] [res['bboxes_3d'].tensor for res in rpn_results_list])
labels_3d = [res['labels_3d'] for res in rpn_results_list]
features = feats_dict['features'] batch_input_metas = [
points = feats_dict['points'] data_samples.metainfo for data_samples in batch_data_samples
]
fp_features = feats_dict['fp_features']
fp_points = feats_dict['fp_points']
point_cls_preds = feats_dict['points_cls_preds'] point_cls_preds = feats_dict['points_cls_preds']
sem_scores = point_cls_preds.sigmoid() sem_scores = point_cls_preds.sigmoid()
point_scores = sem_scores.max(-1)[0] point_scores = sem_scores.max(-1)[0]
features = features.transpose(1, 2).contiguous() features = fp_features.transpose(1, 2).contiguous()
point_depths = points.norm(dim=2) / self.depth_normalizer - 0.5 point_depths = fp_points.norm(dim=2) / self.depth_normalizer - 0.5
features_list = [ features_list = [
point_scores.unsqueeze(2), point_scores.unsqueeze(2),
point_depths.unsqueeze(2), features point_depths.unsqueeze(2), features
...@@ -146,29 +162,27 @@ class PointRCNNRoIHead(Base3DRoIHead): ...@@ -146,29 +162,27 @@ class PointRCNNRoIHead(Base3DRoIHead):
features = torch.cat(features_list, dim=2) features = torch.cat(features_list, dim=2)
batch_size = features.shape[0] batch_size = features.shape[0]
bbox_results = self._bbox_forward(features, points, batch_size, rois) bbox_results = self._bbox_forward(features, fp_points, batch_size,
rois)
object_score = bbox_results['cls_score'].sigmoid() object_score = bbox_results['cls_score'].sigmoid()
bbox_list = self.bbox_head.get_bboxes( bbox_list = self.bbox_head.get_results(
rois, rois,
object_score, object_score,
bbox_results['bbox_pred'], bbox_results['bbox_pred'],
labels_3d, labels_3d,
img_metas, batch_input_metas,
cfg=self.test_cfg) cfg=self.test_cfg)
bbox_results = [ return bbox_list
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results
def _bbox_forward_train(self, features, points, sampling_results): def _bbox_forward_train(self, features: Tensor, points: Tensor,
sampling_results: SampleList) -> dict:
"""Forward training function of roi_extractor and bbox_head. """Forward training function of roi_extractor and bbox_head.
Args: Args:
features (torch.Tensor): Backbone features with depth and \ features (torch.Tensor): Backbone features with depth and \
semantic features. semantic features.
points (torch.Tensor): Pointcloud. points (torch.Tensor): Point cloud.
sampling_results (:obj:`SamplingResult`): Sampled results used sampling_results (:obj:`SamplingResult`): Sampled results used
for training. for training.
...@@ -188,14 +202,15 @@ class PointRCNNRoIHead(Base3DRoIHead): ...@@ -188,14 +202,15 @@ class PointRCNNRoIHead(Base3DRoIHead):
bbox_results.update(loss_bbox=loss_bbox) bbox_results.update(loss_bbox=loss_bbox)
return bbox_results return bbox_results
def _bbox_forward(self, features, points, batch_size, rois): def _bbox_forward(self, features: Tensor, points: Tensor, batch_size: int,
rois: Tensor) -> dict:
"""Forward function of roi_extractor and bbox_head used in both """Forward function of roi_extractor and bbox_head used in both
training and testing. training and testing.
Args: Args:
features (torch.Tensor): Backbone features with depth and features (torch.Tensor): Backbone features with depth and
semantic features. semantic features.
points (torch.Tensor): Pointcloud. points (torch.Tensor): Point cloud.
batch_size (int): Batch size. batch_size (int): Batch size.
rois (torch.Tensor): RoI boxes. rois (torch.Tensor): RoI boxes.
...@@ -203,21 +218,27 @@ class PointRCNNRoIHead(Base3DRoIHead): ...@@ -203,21 +218,27 @@ class PointRCNNRoIHead(Base3DRoIHead):
dict: Contains predictions of bbox_head and dict: Contains predictions of bbox_head and
features of roi_extractor. features of roi_extractor.
""" """
pooled_point_feats = self.point_roi_extractor(features, points, pooled_point_feats = self.bbox_roi_extractor(features, points,
batch_size, rois) batch_size, rois)
cls_score, bbox_pred = self.bbox_head(pooled_point_feats) cls_score, bbox_pred = self.bbox_head(pooled_point_feats)
bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred) bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred)
return bbox_results return bbox_results
def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d): def _assign_and_sample(
self, rpn_results_list: InstanceList,
batch_gt_instances_3d: InstanceList,
batch_gt_instances_ignore: InstanceList) -> SampleList:
"""Assign and sample proposals for training. """Assign and sample proposals for training.
Args: Args:
proposal_list (list[dict]): Proposals produced by RPN. rpn_results_list (List[:obj:`InstanceData`]): Detection results
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth of rpn head.
boxes. batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_labels_3d (list[torch.Tensor]): Ground truth labels gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
batch_gt_instances_ignore (list[:obj:`InstanceData`]): Ignore
instances of gt bboxes.
Returns: Returns:
list[:obj:`SamplingResult`]: Sampled results of each training list[:obj:`SamplingResult`]: Sampled results of each training
...@@ -225,12 +246,16 @@ class PointRCNNRoIHead(Base3DRoIHead): ...@@ -225,12 +246,16 @@ class PointRCNNRoIHead(Base3DRoIHead):
""" """
sampling_results = [] sampling_results = []
# bbox assign # bbox assign
for batch_idx in range(len(proposal_list)): for batch_idx in range(len(rpn_results_list)):
cur_proposal_list = proposal_list[batch_idx] cur_proposal_list = rpn_results_list[batch_idx]
cur_boxes = cur_proposal_list['boxes_3d'] cur_boxes = cur_proposal_list['bboxes_3d']
cur_labels_3d = cur_proposal_list['labels_3d'] cur_labels_3d = cur_proposal_list['labels_3d']
cur_gt_bboxes = gt_bboxes_3d[batch_idx].to(cur_boxes.device) cur_gt_instances_3d = batch_gt_instances_3d[batch_idx]
cur_gt_labels = gt_labels_3d[batch_idx] cur_gt_instances_3d.bboxes_3d = cur_gt_instances_3d.\
bboxes_3d.tensor
cur_gt_instances_ignore = batch_gt_instances_ignore[batch_idx]
cur_gt_bboxes = cur_gt_instances_3d.bboxes_3d.to(cur_boxes.device)
cur_gt_labels = cur_gt_instances_3d.labels_3d
batch_num_gts = 0 batch_num_gts = 0
# 0 is bg # 0 is bg
batch_gt_indis = cur_gt_labels.new_full((len(cur_boxes), ), 0) batch_gt_indis = cur_gt_labels.new_full((len(cur_boxes), ), 0)
...@@ -244,9 +269,9 @@ class PointRCNNRoIHead(Base3DRoIHead): ...@@ -244,9 +269,9 @@ class PointRCNNRoIHead(Base3DRoIHead):
gt_per_cls = (cur_gt_labels == i) gt_per_cls = (cur_gt_labels == i)
pred_per_cls = (cur_labels_3d == i) pred_per_cls = (cur_labels_3d == i)
cur_assign_res = assigner.assign( cur_assign_res = assigner.assign(
cur_boxes.tensor[pred_per_cls], cur_proposal_list[pred_per_cls],
cur_gt_bboxes.tensor[gt_per_cls], cur_gt_instances_3d[gt_per_cls],
gt_labels=cur_gt_labels[gt_per_cls]) cur_gt_instances_ignore)
# gather assign_results in different class into one result # gather assign_results in different class into one result
batch_num_gts += cur_assign_res.num_gts batch_num_gts += cur_assign_res.num_gts
# gt inds (1-based) # gt inds (1-based)
...@@ -272,14 +297,13 @@ class PointRCNNRoIHead(Base3DRoIHead): ...@@ -272,14 +297,13 @@ class PointRCNNRoIHead(Base3DRoIHead):
batch_gt_labels) batch_gt_labels)
else: # for single class else: # for single class
assign_result = self.bbox_assigner.assign( assign_result = self.bbox_assigner.assign(
cur_boxes.tensor, cur_proposal_list, cur_gt_instances_3d,
cur_gt_bboxes.tensor, cur_gt_instances_ignore)
gt_labels=cur_gt_labels)
# sample boxes # sample boxes
sampling_result = self.bbox_sampler.sample(assign_result, sampling_result = self.bbox_sampler.sample(assign_result,
cur_boxes.tensor, cur_boxes.tensor,
cur_gt_bboxes.tensor, cur_gt_bboxes,
cur_gt_labels) cur_gt_labels)
sampling_results.append(sampling_result) sampling_results.append(sampling_result)
return sampling_results return sampling_results
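For multi-class training the assigner list is applied class by class and the per-class results are merged into one batch-wide assignment; gt indices are 1-based with 0 as background, so each class's positive indices are offset before being written back. A simplified, hypothetical sketch of that merge:

import torch

def merge_class_assigns(labels_3d, per_cls_gt_inds, per_cls_num_gts):
    # labels_3d: (num_boxes,) predicted class of each proposal
    # per_cls_gt_inds: {cls_id: 1-based gt indices for that class's proposals}
    merged = labels_3d.new_zeros(labels_3d.shape, dtype=torch.long)
    offset = 0
    for cls_id, gt_inds in per_cls_gt_inds.items():
        mask = labels_3d == cls_id
        shifted = gt_inds.clone()
        shifted[shifted > 0] += offset   # keep gt indices unique across classes
        merged[mask] = shifted
        offset += per_cls_num_gts[cls_id]
    return merged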
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch import torch
import torch.nn as nn
from mmcv import ops from mmcv import ops
from mmengine.model import BaseModule from mmengine.model import BaseModule
from torch import Tensor
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
...@@ -13,14 +17,16 @@ class Single3DRoIAwareExtractor(BaseModule): ...@@ -13,14 +17,16 @@ class Single3DRoIAwareExtractor(BaseModule):
Extract Point-wise roi features. Extract Point-wise roi features.
Args: Args:
roi_layer (dict): The config of roi layer. roi_layer (dict, optional): The config of roi layer.
""" """
def __init__(self, roi_layer=None, init_cfg=None): def __init__(self,
roi_layer: Optional[dict] = None,
init_cfg: Optional[dict] = None) -> None:
super(Single3DRoIAwareExtractor, self).__init__(init_cfg=init_cfg) super(Single3DRoIAwareExtractor, self).__init__(init_cfg=init_cfg)
self.roi_layer = self.build_roi_layers(roi_layer) self.roi_layer = self.build_roi_layers(roi_layer)
def build_roi_layers(self, layer_cfg): def build_roi_layers(self, layer_cfg: dict) -> nn.Module:
"""Build roi layers using `layer_cfg`""" """Build roi layers using `layer_cfg`"""
cfg = layer_cfg.copy() cfg = layer_cfg.copy()
layer_type = cfg.pop('type') layer_type = cfg.pop('type')
...@@ -29,7 +35,8 @@ class Single3DRoIAwareExtractor(BaseModule): ...@@ -29,7 +35,8 @@ class Single3DRoIAwareExtractor(BaseModule):
roi_layers = layer_cls(**cfg) roi_layers = layer_cls(**cfg)
return roi_layers return roi_layers
def forward(self, feats, coordinate, batch_inds, rois): def forward(self, feats: Tensor, coordinate: Tensor, batch_inds: Tensor,
rois: Tensor) -> Tensor:
"""Extract point-wise roi features. """Extract point-wise roi features.
Args: Args:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch import torch
import torch.nn as nn
from mmcv import ops from mmcv import ops
from torch import nn as nn from torch import Tensor
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet3d.structures.bbox_3d import rotation_3d_in_axis from mmdet3d.structures.bbox_3d import rotation_3d_in_axis
...@@ -14,14 +17,14 @@ class Single3DRoIPointExtractor(nn.Module): ...@@ -14,14 +17,14 @@ class Single3DRoIPointExtractor(nn.Module):
Extract Point-wise roi features. Extract Point-wise roi features.
Args: Args:
roi_layer (dict): The config of roi layer. roi_layer (dict, optional): The config of roi layer.
""" """
def __init__(self, roi_layer=None): def __init__(self, roi_layer: Optional[dict] = None) -> None:
super(Single3DRoIPointExtractor, self).__init__() super(Single3DRoIPointExtractor, self).__init__()
self.roi_layer = self.build_roi_layers(roi_layer) self.roi_layer = self.build_roi_layers(roi_layer)
def build_roi_layers(self, layer_cfg): def build_roi_layers(self, layer_cfg: dict) -> nn.Module:
"""Build roi layers using `layer_cfg`""" """Build roi layers using `layer_cfg`"""
cfg = layer_cfg.copy() cfg = layer_cfg.copy()
layer_type = cfg.pop('type') layer_type = cfg.pop('type')
...@@ -30,7 +33,8 @@ class Single3DRoIPointExtractor(nn.Module): ...@@ -30,7 +33,8 @@ class Single3DRoIPointExtractor(nn.Module):
roi_layers = layer_cls(**cfg) roi_layers = layer_cls(**cfg)
return roi_layers return roi_layers
def forward(self, feats, coordinate, batch_inds, rois): def forward(self, feats: Tensor, coordinate: Tensor, batch_inds: Tensor,
rois: Tensor) -> Tensor:
"""Extract point-wise roi features. """Extract point-wise roi features.
Args: Args:
......
...@@ -5,6 +5,7 @@ from typing import List, Optional, Union ...@@ -5,6 +5,7 @@ from typing import List, Optional, Union
from mmengine.config import ConfigDict from mmengine.config import ConfigDict
from mmengine.structures import InstanceData from mmengine.structures import InstanceData
from mmdet3d.structures.det3d_data_sample import Det3DDataSample
from mmdet.models.task_modules.samplers import SamplingResult from mmdet.models.task_modules.samplers import SamplingResult
# Type hint of config data # Type hint of config data
...@@ -21,3 +22,4 @@ OptInstanceList = Optional[InstanceList] ...@@ -21,3 +22,4 @@ OptInstanceList = Optional[InstanceList]
SamplingResultList = List[SamplingResult] SamplingResultList = List[SamplingResult]
OptSamplingResultList = Optional[SamplingResultList] OptSamplingResultList = Optional[SamplingResultList]
SampleList = List[Det3DDataSample]
import unittest
import torch
from mmengine import DefaultScope
from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
_get_detector_cfg, _setup_seed)
class TestPointRCNN(unittest.TestCase):
def test_pointrcnn(self):
import mmdet3d.models
assert hasattr(mmdet3d.models, 'PointRCNN')
DefaultScope.get_instance('test_pointrcnn', scope_name='mmdet3d')
_setup_seed(0)
pointrcnn_cfg = _get_detector_cfg(
'point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py')
model = MODELS.build(pointrcnn_cfg)
num_gt_instance = 2
packed_inputs = _create_detector_inputs(
num_points=10101, num_gt_instance=num_gt_instance)
if torch.cuda.is_available():
model = model.cuda()
# test simple_test
with torch.no_grad():
data = model.data_preprocessor(packed_inputs, True)
torch.cuda.empty_cache()
results = model.forward(**data, mode='predict')
self.assertEqual(len(results), 1)
self.assertIn('bboxes_3d', results[0].pred_instances_3d)
self.assertIn('scores_3d', results[0].pred_instances_3d)
self.assertIn('labels_3d', results[0].pred_instances_3d)
# save the memory
with torch.no_grad():
losses = model.forward(**data, mode='loss')
torch.cuda.empty_cache()
self.assertGreaterEqual(losses['rpn_bbox_loss'], 0)
self.assertGreaterEqual(losses['rpn_semantic_loss'], 0)
self.assertGreaterEqual(losses['loss_cls'], 0)
self.assertGreaterEqual(losses['loss_bbox'], 0)
self.assertGreaterEqual(losses['loss_corner'], 0)