Commit 0e17beab authored by jshilong, committed by ChaimZhu

[Refactor] Refactor H3DNet

parent 9ebb75da
......@@ -30,7 +30,7 @@ primitive_z_cfg = dict(
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.4, 0.6],
reduction='mean',
loss_weight=30.0),
......@@ -47,14 +47,16 @@ primitive_z_cfg = dict(
loss_src_weight=0.5,
loss_dst_weight=0.5),
semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
train_cfg=dict(
sample_mode='vote',
dist_thresh=0.2,
var_thresh=1e-2,
lower_thresh=1e-6,
num_point=100,
num_point_line=10,
line_thresh=0.2))
line_thresh=0.2),
test_cfg=dict(sample_mode='seed'))
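# The `mmdet.` prefix above is an MMEngine registry scope: it routes the
# lookup to MMDetection's MODELS registry instead of MMDetection3D's own.
# A minimal resolution sketch (assuming the `mmdet3d.registry` API that this
# commit uses elsewhere):
from mmdet3d.registry import MODELS

loss_cfg = dict(
    type='mmdet.CrossEntropyLoss',  # scope prefix -> mmdet's registry
    reduction='sum',
    loss_weight=1.0)
loss = MODELS.build(loss_cfg)  # builds mmdet's CrossEntropyLoss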
primitive_xy_cfg = dict(
type='PrimitiveHead',
......@@ -88,7 +90,7 @@ primitive_xy_cfg = dict(
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.4, 0.6],
reduction='mean',
loss_weight=30.0),
......@@ -105,14 +107,16 @@ primitive_xy_cfg = dict(
loss_src_weight=0.5,
loss_dst_weight=0.5),
semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
train_cfg=dict(
sample_mode='vote',
dist_thresh=0.2,
var_thresh=1e-2,
lower_thresh=1e-6,
num_point=100,
num_point_line=10,
line_thresh=0.2))
line_thresh=0.2),
test_cfg=dict(sample_mode='seed'))
primitive_line_cfg = dict(
type='PrimitiveHead',
......@@ -146,7 +150,7 @@ primitive_line_cfg = dict(
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.4, 0.6],
reduction='mean',
loss_weight=30.0),
......@@ -163,17 +167,20 @@ primitive_line_cfg = dict(
loss_src_weight=1.0,
loss_dst_weight=1.0),
semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=2.0),
type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=2.0),
train_cfg=dict(
sample_mode='vote',
dist_thresh=0.2,
var_thresh=1e-2,
lower_thresh=1e-6,
num_point=100,
num_point_line=10,
line_thresh=0.2))
line_thresh=0.2),
test_cfg=dict(sample_mode='seed'))
model = dict(
type='H3DNet',
data_preprocessor=dict(type='Det3DDataPreprocessor'),
backbone=dict(
type='MultiBackbone',
num_streams=4,
......@@ -221,10 +228,8 @@ model = dict(
normalize_xyz=True),
pred_layer_cfg=dict(
in_channels=128, shared_conv_channels=(128, 128), bias=True),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='sum',
loss_weight=5.0),
......@@ -235,15 +240,15 @@ model = dict(
loss_src_weight=10.0,
loss_dst_weight=10.0),
dir_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
dir_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
size_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
size_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
semantic_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)),
type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0)),
roi_head=dict(
type='H3DRoIHead',
primitive_list=[primitive_z_cfg, primitive_xy_cfg, primitive_line_cfg],
......@@ -267,7 +272,6 @@ model = dict(
mlp_channels=[128 + 12, 128, 64, 32],
use_xyz=True,
normalize_xyz=True),
feat_channels=(128, 128),
primitive_refine_channels=[128, 128, 128],
upper_thresh=100.0,
surface_thresh=0.5,
......@@ -275,7 +279,7 @@ model = dict(
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='sum',
loss_weight=5.0),
......@@ -286,41 +290,47 @@ model = dict(
loss_src_weight=10.0,
loss_dst_weight=10.0),
dir_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
type='mmdet.CrossEntropyLoss',
reduction='sum',
loss_weight=0.1),
dir_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
size_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
type='mmdet.CrossEntropyLoss',
reduction='sum',
loss_weight=0.1),
size_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
semantic_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
type='mmdet.CrossEntropyLoss',
reduction='sum',
loss_weight=0.1),
cues_objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
cues_semantic_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
proposal_objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='none',
loss_weight=5.0),
primitive_center_loss=dict(
type='MSELoss', reduction='none', loss_weight=1.0))),
type='mmdet.MSELoss', reduction='none', loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'),
pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mode='vote'),
rpn_proposal=dict(use_nms=False),
rcnn=dict(
pos_distance_thr=0.3,
neg_distance_thr=0.6,
sample_mod='vote',
sample_mode='vote',
far_threshold=0.6,
near_threshold=0.3,
mask_surface_threshold=0.3,
......@@ -329,13 +339,13 @@ model = dict(
label_line_threshold=0.3)),
test_cfg=dict(
rpn=dict(
sample_mod='seed',
sample_mode='seed',
nms_thr=0.25,
score_thr=0.05,
per_class_proposal=True,
use_nms=False),
rcnn=dict(
sample_mod='seed',
sample_mode='seed',
nms_thr=0.25,
score_thr=0.05,
per_class_proposal=True)))
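# Usage sketch for the refactored config (the file path and the
# `register_all_modules` helper are assumptions, not part of this diff):
from mmengine import Config

from mmdet3d.registry import MODELS
from mmdet3d.utils import register_all_modules  # assumed 1.x helper

register_all_modules()  # makes scoped types such as 'mmdet.*' resolvable
cfg = Config.fromfile('configs/_base_/models/h3dnet.py')  # path assumed
model = MODELS.build(cfg.model)  # sample_mode is read from train/test cfg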
_base_ = [
'../_base_/datasets/scannet-3d-18class.py', '../_base_/models/h3dnet.py',
'../_base_/schedules/schedule_3x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
rpn_head=dict(
num_classes=18,
bbox_coder=dict(
type='PartialBinBasedBBoxCoder',
num_sizes=18,
num_dir_bins=24,
with_rot=False,
mean_sizes=[[0.76966727, 0.8116021, 0.92573744],
[1.876858, 1.8425595, 1.1931566],
[0.61328, 0.6148609, 0.7182701],
[1.3955007, 1.5121545, 0.83443564],
[0.97949594, 1.0675149, 0.6329687],
[0.531663, 0.5955577, 1.7500148],
[0.9624706, 0.72462326, 1.1481868],
[0.83221924, 1.0490936, 1.6875663],
[0.21132214, 0.4206159, 0.5372846],
[1.4440073, 1.8970833, 0.26985747],
[1.0294262, 1.4040797, 0.87554324],
[1.3766412, 0.65521795, 1.6813129],
[0.6650819, 0.71111923, 1.298853],
[0.41999173, 0.37906948, 1.7513971],
[0.59359556, 0.5912492, 0.73919016],
[0.50867593, 0.50656086, 0.30136237],
[1.1511526, 1.0546296, 0.49706793],
[0.47535285, 0.49249494, 0.5802117]])),
roi_head=dict(
bbox_head=dict(
num_classes=18,
bbox_coder=dict(
type='PartialBinBasedBBoxCoder',
num_sizes=18,
num_dir_bins=24,
with_rot=False,
mean_sizes=[[0.76966727, 0.8116021, 0.92573744],
[1.876858, 1.8425595, 1.1931566],
[0.61328, 0.6148609, 0.7182701],
[1.3955007, 1.5121545, 0.83443564],
[0.97949594, 1.0675149, 0.6329687],
[0.531663, 0.5955577, 1.7500148],
[0.9624706, 0.72462326, 1.1481868],
[0.83221924, 1.0490936, 1.6875663],
[0.21132214, 0.4206159, 0.5372846],
[1.4440073, 1.8970833, 0.26985747],
[1.0294262, 1.4040797, 0.87554324],
[1.3766412, 0.65521795, 1.6813129],
[0.6650819, 0.71111923, 1.298853],
[0.41999173, 0.37906948, 1.7513971],
[0.59359556, 0.5912492, 0.73919016],
[0.50867593, 0.50656086, 0.30136237],
[1.1511526, 1.0546296, 0.49706793],
[0.47535285, 0.49249494, 0.5802117]]))))
train_dataloader = dict(
batch_size=3,
num_workers=2,
)
# yapf:disable
default_hooks = dict(
logger=dict(type='LoggerHook', interval=30)
)
# yapf:enable
......@@ -57,8 +57,13 @@ model = dict(
[1.1511526, 1.0546296, 0.49706793],
[0.47535285, 0.49249494, 0.5802117]]))))
data = dict(samples_per_gpu=3, workers_per_gpu=2)
train_dataloader = dict(
batch_size=3,
num_workers=2,
)
# yapf:disable
log_config = dict(interval=30)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=30)
)
# yapf:enable
......@@ -5,11 +5,9 @@ from typing import Callable, List, Optional, Union
import numpy as np
from mmdet3d.core import show_result
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.registry import DATASETS
from .det3d_dataset import Det3DDataset
from .pipelines import Compose
from .seg3d_dataset import Seg3DDataset
......@@ -151,46 +149,6 @@ class ScanNetDataset(Det3DDataset):
return ann_info
def _build_default_pipeline(self):
"""Build the default pipeline for this dataset."""
pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='DefaultFormatBundle3D',
class_names=self.CLASSES,
with_label=False),
dict(type='Collect3D', keys=['points'])
]
return Compose(pipeline)
def show(self, results, out_dir, show=True, pipeline=None):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
show (bool): Visualize the results online.
pipeline (list[dict], optional): raw data loading for showing.
Default: None.
"""
assert out_dir is not None, 'Expect out_dir, got none.'
pipeline = self._get_pipeline(pipeline)
for i, result in enumerate(results):
data_info = self.get_data_info[i]
pts_path = data_info['lidar_points']['lidar_path']
file_name = osp.split(pts_path)[-1].split('.')[0]
points = self._extract_data(i, pipeline, 'points').numpy()
gt_bboxes = self.get_ann_info(i)['gt_bboxes_3d'].tensor.numpy()
pred_bboxes = result['boxes_3d'].tensor.numpy()
show_result(points, gt_bboxes, pred_bboxes, out_dir, file_name,
show)
@DATASETS.register_module()
class ScanNetSegDataset(Seg3DDataset):
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule
from mmengine import ConfigDict, InstanceData
from torch import Tensor
from torch.nn import functional as F
from mmdet3d.core.post_processing import aligned_3d_nms
......@@ -161,7 +162,7 @@ class VoteHead(BaseModule):
points: List[torch.Tensor],
feats_dict: Dict[str, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
rescale=True,
use_nms: bool = True,
**kwargs) -> List[InstanceData]:
"""
Args:
......@@ -169,8 +170,8 @@ class VoteHead(BaseModule):
feats_dict (dict): Features from FPN or backbone.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes meta information of data.
rescale (bool): Whether rescale the resutls to
the original scale.
use_nms (bool): Whether to do NMS for predictions.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: List of processed predictions. Each
......@@ -178,6 +179,9 @@ class VoteHead(BaseModule):
scores and labels.
"""
preds_dict = self(feats_dict)
# `preds_dict` can be used in H3DNet
feats_dict.update(preds_dict)
batch_size = len(batch_data_samples)
batch_input_metas = []
for batch_index in range(batch_size):
......@@ -185,12 +189,73 @@ class VoteHead(BaseModule):
batch_input_metas.append(metainfo)
results_list = self.predict_by_feat(
points, preds_dict, batch_input_metas, rescale=rescale, **kwargs)
points, preds_dict, batch_input_metas, use_nms=use_nms, **kwargs)
return results_list
def loss(self, points: List[torch.Tensor], feats_dict: Dict[str,
torch.Tensor],
batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
def loss_and_predict(self,
points: List[torch.Tensor],
feats_dict: Dict[str, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
ret_target: bool = False,
proposal_cfg: dict = None,
**kwargs) -> Tuple:
"""
Args:
points (list[tensor]): Point clouds of multiple samples.
feats_dict (dict): Predictions from backbone or FPN.
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each sample and
corresponding annotations.
ret_target (bool): Whether to return the assigned targets.
Defaults to False.
proposal_cfg (dict): Config for the proposal process.
Defaults to None.
Returns:
tuple: Contains loss and predictions after post-process.
"""
preds_dict = self.forward(feats_dict)
feats_dict.update(preds_dict)
batch_gt_instance_3d = []
batch_gt_instances_ignore = []
batch_input_metas = []
batch_pts_semantic_mask = []
batch_pts_instance_mask = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
batch_gt_instance_3d.append(data_sample.gt_instances_3d)
batch_gt_instances_ignore.append(
data_sample.get('ignored_instances', None))
batch_pts_semantic_mask.append(
data_sample.gt_pts_seg.get('pts_semantic_mask', None))
batch_pts_instance_mask.append(
data_sample.gt_pts_seg.get('pts_instance_mask', None))
loss_inputs = (points, preds_dict, batch_gt_instance_3d)
losses = self.loss_by_feat(
*loss_inputs,
batch_pts_semantic_mask=batch_pts_semantic_mask,
batch_pts_instance_mask=batch_pts_instance_mask,
batch_input_metas=batch_input_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore,
ret_target=ret_target,
**kwargs)
results_list = self.predict_by_feat(
points,
preds_dict,
batch_input_metas,
use_nms=proposal_cfg.use_nms,
**kwargs)
return losses, results_list
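# Caller sketch for the new joint path (mirrors `H3DNet.loss` later in this
# commit): one head forward yields both the RPN losses and the proposals,
# and `ret_target=True` additionally returns the assigned targets for the
# RoI stage under the 'targets' key.
proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
rpn_losses, rpn_proposals = self.rpn_head.loss_and_predict(
    points,
    feats_dict,
    batch_data_samples,
    ret_target=True,
    proposal_cfg=proposal_cfg)
targets = rpn_losses.pop('targets')  # later consumed by the RoI head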
def loss(self,
points: List[torch.Tensor],
feats_dict: Dict[str, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
ret_target: bool = False,
**kwargs) -> dict:
"""
Args:
points (list[tensor]): Point clouds of multiple samples.
......@@ -198,6 +263,8 @@ class VoteHead(BaseModule):
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each sample and
corresponding annotations.
ret_target (bool): Whether to return the assigned targets.
Defaults to False.
Returns:
dict: A dictionary of loss components.
......@@ -224,7 +291,9 @@ class VoteHead(BaseModule):
batch_pts_semantic_mask=batch_pts_semantic_mask,
batch_pts_instance_mask=batch_pts_instance_mask,
batch_input_metas=batch_input_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore)
batch_gt_instances_ignore=batch_gt_instances_ignore,
ret_target=ret_target,
**kwargs)
return losses
def forward(self, feat_dict: dict) -> dict:
......@@ -330,7 +399,7 @@ class VoteHead(BaseModule):
batch_pts_semantic_mask (list[tensor]): Semantic mask
of point clouds. Defaults to None.
batch_input_metas (list[dict]): Contain pcd and img's meta info.
ret_target (bool): Return targets or not.
ret_target (bool): Return targets or not. Defaults to False.
Returns:
dict: Losses of Votenet.
......@@ -671,9 +740,10 @@ class VoteHead(BaseModule):
while using vote head in rpn stage.
Returns:
list[:obj:`InstanceData`]: List of processed predictions. Each
InstanceData cantains 3d Bounding boxes and corresponding
scores and labels.
list[:obj:`InstanceData`] or Tensor: Return a list of processed
predictions when `use_nms` is True. Each InstanceData contains
3D bounding boxes and corresponding scores and labels.
Return raw bboxes when `use_nms` is False.
"""
# decode boxes
stack_points = torch.stack(points)
......@@ -683,9 +753,9 @@ class VoteHead(BaseModule):
batch_size = bbox3d.shape[0]
results_list = list()
for b in range(batch_size):
temp_results = InstanceData()
if use_nms:
if use_nms:
for b in range(batch_size):
temp_results = InstanceData()
bbox_selected, score_selected, labels = \
self.multiclass_nms_single(obj_scores[b],
sem_scores[b],
......@@ -700,20 +770,15 @@ class VoteHead(BaseModule):
temp_results.scores_3d = score_selected
temp_results.labels_3d = labels
results_list.append(temp_results)
else:
bbox = batch_input_metas[b]['box_type_3d'](
bbox_selected,
box_dim=bbox_selected.shape[-1],
with_yaw=self.bbox_coder.with_rot)
temp_results.bboxes_3d = bbox
temp_results.obj_scores_3d = obj_scores[b]
temp_results.sem_scores_3d = obj_scores[b]
results_list.append(temp_results)
return results_list
return results_list
else:
# TODO unify it when refactor the Augtest
return bbox3d
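# Sketch of the two return modes (the raw tensor shape is assumed from
# `self.bbox_coder.decode`): with NMS the head yields per-sample
# InstanceData; without it, the raw decoded boxes that H3DNet forwards
# as 'rpn_proposals'.
results = self.predict_by_feat(points, preds_dict, input_metas, use_nms=True)
boxes = results[0].bboxes_3d  # boxes surviving multiclass NMS
raw = self.predict_by_feat(points, preds_dict, input_metas, use_nms=False)
# raw: (batch, num_proposal, 7) decoded boxes, before any NMS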
def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
input_meta):
def multiclass_nms_single(self, obj_scores: Tensor, sem_scores: Tensor,
bbox: Tensor, points: Tensor,
input_meta: dict) -> Tuple:
"""Multi-class nms in single batch.
Args:
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
from torch import Tensor
from mmdet3d.core import merge_aug_bboxes_3d
from mmdet3d.registry import MODELS
from ...core import Det3DDataSample
from .two_stage import TwoStage3DDetector
......@@ -11,17 +14,33 @@ class H3DNet(TwoStage3DDetector):
r"""H3DNet model.
Please refer to the `paper <https://arxiv.org/abs/2006.05682>`_
Args:
backbone (dict): Config dict of detector's backbone.
neck (dict, optional): Config dict of neck. Defaults to None.
rpn_head (dict, optional): Config dict of rpn head. Defaults to None.
roi_head (dict, optional): Config dict of roi head. Defaults to None.
train_cfg (dict, optional): Config dict of training hyper-parameters.
Defaults to None.
test_cfg (dict, optional): Config dict of test hyper-parameters.
Defaults to None.
init_cfg (dict, optional): The config to control the
initialization. Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`BaseDataPreprocessor`. It usually includes
``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
"""
def __init__(self,
backbone,
neck=None,
rpn_head=None,
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
backbone: dict,
neck: Optional[dict] = None,
rpn_head: Optional[dict] = None,
roi_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
data_preprocessor: Optional[dict] = None,
**kwargs) -> None:
super(H3DNet, self).__init__(
backbone=backbone,
neck=neck,
......@@ -29,148 +48,110 @@ class H3DNet(TwoStage3DDetector):
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg)
def forward_train(self,
points,
img_metas,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
gt_bboxes_ignore=None):
"""Forward of training.
init_cfg=init_cfg,
data_preprocessor=data_preprocessor,
**kwargs)
def extract_feat(self, batch_inputs_dict: dict) -> dict:
"""Directly extract features from the backbone+neck.
Args:
points (list[torch.Tensor]): Points of each batch.
img_metas (list): Image metas.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
pts_semantic_mask (list[torch.Tensor]): point-wise semantic
label of each batch.
pts_instance_mask (list[torch.Tensor]): point-wise instance
label of each batch.
gt_bboxes_ignore (list[torch.Tensor]): Specify
which bounding.
batch_inputs_dict (dict): The model input dict which includes
'points'.
- points (list[torch.Tensor]): Point cloud of each sample.
Returns:
dict: Losses.
dict: Dict of features.
"""
points_cat = torch.stack(points)
stack_points = torch.stack(batch_inputs_dict['points'])
x = self.backbone(stack_points)
if self.with_neck:
x = self.neck(x)
return x
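# Illustrative call (the point layout is assumed; the keys 'fp_xyz_net0',
# 'hd_feature' and 'fp_indices_net0' come from H3DNet's MultiBackbone and
# are remapped in `loss`/`predict` below):
batch_inputs_dict = dict(points=[torch.rand(40000, 3) for _ in range(2)])
feats_dict = self.extract_feat(batch_inputs_dict)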
def loss(self, batch_inputs_dict: Dict[str, Union[List, Tensor]],
batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
"""
Args:
batch_inputs_dict (dict): The model input dict which includes
the 'points' key.
- points (list[torch.Tensor]): Point cloud of each sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
feats_dict = self.extract_feat(batch_inputs_dict)
feats_dict = self.extract_feat(points_cat)
feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
feats_dict['fp_features'] = [feats_dict['hd_feature']]
feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
losses = dict()
if self.with_rpn:
rpn_outs = self.rpn_head(feats_dict, self.train_cfg.rpn.sample_mod)
feats_dict.update(rpn_outs)
rpn_loss_inputs = (points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask, img_metas)
rpn_losses = self.rpn_head.loss(
rpn_outs,
*rpn_loss_inputs,
gt_bboxes_ignore=gt_bboxes_ignore,
ret_target=True)
feats_dict['targets'] = rpn_losses.pop('targets')
losses.update(rpn_losses)
# Generate rpn proposals
proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn)
proposal_inputs = (points, rpn_outs, img_metas)
proposal_list = self.rpn_head.get_bboxes(
*proposal_inputs, use_nms=proposal_cfg.use_nms)
feats_dict['proposal_list'] = proposal_list
# note: new keys and values will be added into feats_dict by rpn_head
rpn_losses, rpn_proposals = self.rpn_head.loss_and_predict(
batch_inputs_dict['points'],
feats_dict,
batch_data_samples,
ret_target=True,
proposal_cfg=proposal_cfg)
feats_dict['targets'] = rpn_losses.pop('targets')
losses.update(rpn_losses)
feats_dict['rpn_proposals'] = rpn_proposals
else:
raise NotImplementedError
roi_losses = self.roi_head.forward_train(feats_dict, img_metas, points,
gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask,
pts_instance_mask,
gt_bboxes_ignore)
roi_losses = self.roi_head.loss(batch_inputs_dict['points'],
feats_dict, batch_data_samples,
**kwargs)
losses.update(roi_losses)
return losses
def simple_test(self, points, img_metas, imgs=None, rescale=False):
"""Forward of testing.
def predict(
self, batch_input_dict: Dict,
batch_data_samples: List[Det3DDataSample]
) -> List[Det3DDataSample]:
"""Get model predictions.
Args:
points (list[torch.Tensor]): Points of each sample.
img_metas (list): Image metas.
rescale (bool): Whether to rescale results.
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each sample and
corresponding annotations.
Returns:
list[:obj:`Det3DDataSample`]: Prediction results.
"""
points_cat = torch.stack(points)
feats_dict = self.extract_feat(points_cat)
feats_dict = self.extract_feat(batch_input_dict)
feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
feats_dict['fp_features'] = [feats_dict['hd_feature']]
feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
if self.with_rpn:
proposal_cfg = self.test_cfg.rpn
rpn_outs = self.rpn_head(feats_dict, proposal_cfg.sample_mod)
feats_dict.update(rpn_outs)
# Generate rpn proposals
proposal_list = self.rpn_head.get_bboxes(
points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms)
feats_dict['proposal_list'] = proposal_list
rpn_proposals = self.rpn_head.predict(
batch_input_dict['points'],
feats_dict,
batch_data_samples,
use_nms=proposal_cfg.use_nms)
feats_dict['rpn_proposals'] = rpn_proposals
else:
raise NotImplementedError
return self.roi_head.simple_test(
feats_dict, img_metas, points_cat, rescale=rescale)
def aug_test(self, points, img_metas, imgs=None, rescale=False):
"""Test with augmentation."""
points_cat = [torch.stack(pts) for pts in points]
feats_dict = self.extract_feats(points_cat, img_metas)
for feat_dict in feats_dict:
feat_dict['fp_xyz'] = [feat_dict['fp_xyz_net0'][-1]]
feat_dict['fp_features'] = [feat_dict['hd_feature']]
feat_dict['fp_indices'] = [feat_dict['fp_indices_net0'][-1]]
# only support aug_test for one sample
aug_bboxes = []
for feat_dict, pts_cat, img_meta in zip(feats_dict, points_cat,
img_metas):
if self.with_rpn:
proposal_cfg = self.test_cfg.rpn
rpn_outs = self.rpn_head(feat_dict, proposal_cfg.sample_mod)
feat_dict.update(rpn_outs)
# Generate rpn proposals
proposal_list = self.rpn_head.get_bboxes(
points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms)
feat_dict['proposal_list'] = proposal_list
else:
raise NotImplementedError
bbox_results = self.roi_head.simple_test(
feat_dict,
self.test_cfg.rcnn.sample_mod,
img_meta,
pts_cat,
rescale=rescale)
aug_bboxes.append(bbox_results)
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)
return [merged_bboxes]
def extract_feats(self, points, img_metas):
"""Extract features of multiple samples."""
return [
self.extract_feat(pts, img_meta)
for pts, img_meta in zip(points, img_metas)
]
results_list = self.roi_head.predict(
batch_input_dict['points'],
feats_dict,
batch_data_samples,
suffix='_optimized')
return self.convert_to_datasample(results_list)
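# End-to-end inference sketch (the `pred_instances_3d` attribute on the
# returned data samples is assumed from the 1.x convention; `pts` and
# `data_samples` are hypothetical inputs):
batch_inputs = dict(points=[pts])  # pts: (N, 3) point tensor, layout assumed
results = model.predict(batch_inputs, data_samples)
boxes = results[0].pred_instances_3d.bboxes_3d  # '_optimized' stage boxes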
......@@ -56,12 +56,12 @@ class PointRCNN(TwoStage3DDetector):
x = self.neck(x)
return x
def forward_train(self, points, img_metas, gt_bboxes_3d, gt_labels_3d):
def forward_train(self, points, input_metas, gt_bboxes_3d, gt_labels_3d):
"""Forward of training.
Args:
points (list[torch.Tensor]): Points of each batch.
img_metas (list[dict]): Meta information of each sample.
input_metas (list[dict]): Meta information of each sample.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
......@@ -69,8 +69,8 @@ class PointRCNN(TwoStage3DDetector):
dict: Losses.
"""
losses = dict()
points_cat = torch.stack(points)
x = self.extract_feat(points_cat)
stack_points = torch.stack(points)
x = self.extract_feat(stack_points)
# features for rcnn
backbone_feats = x['fp_features'].clone()
......@@ -85,11 +85,11 @@ class PointRCNN(TwoStage3DDetector):
points=points,
gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=gt_labels_3d,
img_metas=img_metas)
input_metas=input_metas)
losses.update(rpn_loss)
bbox_list = self.rpn_head.get_bboxes(points_cat, bbox_preds, cls_preds,
img_metas)
bbox_list = self.rpn_head.get_bboxes(stack_points, bbox_preds,
cls_preds, input_metas)
proposal_list = [
dict(
boxes_3d=bboxes,
......@@ -100,7 +100,7 @@ class PointRCNN(TwoStage3DDetector):
]
rcnn_feats.update({'points_cls_preds': cls_preds})
roi_losses = self.roi_head.forward_train(rcnn_feats, img_metas,
roi_losses = self.roi_head.forward_train(rcnn_feats, input_metas,
proposal_list, gt_bboxes_3d,
gt_labels_3d)
losses.update(roi_losses)
......@@ -121,9 +121,9 @@ class PointRCNN(TwoStage3DDetector):
Returns:
list: Predicted 3d boxes.
"""
points_cat = torch.stack(points)
stack_points = torch.stack(points)
x = self.extract_feat(points_cat)
x = self.extract_feat(stack_points)
# features for rcnn
backbone_feats = x['fp_features'].clone()
backbone_xyz = x['fp_xyz'].clone()
......@@ -132,7 +132,7 @@ class PointRCNN(TwoStage3DDetector):
rcnn_feats.update({'points_cls_preds': cls_preds})
bbox_list = self.rpn_head.get_bboxes(
points_cat, bbox_preds, cls_preds, img_metas, rescale=rescale)
stack_points, bbox_preds, cls_preds, img_metas, rescale=rescale)
proposal_list = [
dict(
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmengine import InstanceData
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F
from mmdet3d.core import build_bbox_coder
from mmdet3d.core import BaseInstance3DBoxes, Det3DDataSample
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.ops import build_sa_module
from mmdet3d.registry import MODELS
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet.core import multi_apply
......@@ -25,66 +28,73 @@ class H3DBboxHead(BaseModule):
line_matching_cfg (dict): Config for line primitive matching.
bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
decoding boxes.
train_cfg (dict): Config for training.
test_cfg (dict): Config for testing.
train_cfg (dict): Config for training. Defaults to None.
test_cfg (dict): Config for testing. Defaults to None.
gt_per_seed (int): Number of ground truth votes generated
from each seed point.
from each seed point. Defaults to 1.
num_proposal (int): Number of proposal votes generated.
feat_channels (tuple[int]): Convolution channels of
prediction layer.
Defaults to 256.
primitive_feat_refine_streams (int): The number of mlps to
refine primitive feature.
refine primitive feature. Defaults to 2.
primitive_refine_channels (tuple[int]): Convolution channels of
prediction layer.
upper_thresh (float): Threshold for line matching.
prediction layer. Defaults to [128, 128, 128].
upper_thresh (float): Threshold for line matching. Defaults to 100.
surface_thresh (float): Threshold for surface matching.
line_thresh (float): Threshold for line matching.
Defaults to 0.5.
line_thresh (float): Threshold for line matching. Defaults to 0.5.
conv_cfg (dict): Config of convolution in prediction layer.
norm_cfg (dict): Config of BN in prediction layer.
objectness_loss (dict): Config of objectness loss.
center_loss (dict): Config of center loss.
Defaults to None.
norm_cfg (dict): Config of BN in prediction layer. Defaults to None.
objectness_loss (dict): Config of objectness loss. Defaults to None.
center_loss (dict): Config of center loss. Defaults to None.
dir_class_loss (dict): Config of direction classification loss.
Defaults to None.
dir_res_loss (dict): Config of direction residual regression loss.
Defaults to None.
size_class_loss (dict): Config of size classification loss.
Defaults to None.
size_res_loss (dict): Config of size residual regression loss.
Defaults to None.
semantic_loss (dict): Config of point-wise semantic segmentation loss.
Defaults to None.
cues_objectness_loss (dict): Config of cues objectness loss.
Defaults to None.
cues_semantic_loss (dict): Config of cues semantic loss.
Defaults to None.
proposal_objectness_loss (dict): Config of proposal objectness
loss.
loss. Defaults to None.
primitive_center_loss (dict): Config of primitive center regression
loss.
loss. Defaults to None.
"""
def __init__(self,
num_classes,
suface_matching_cfg,
line_matching_cfg,
bbox_coder,
train_cfg=None,
test_cfg=None,
gt_per_seed=1,
num_proposal=256,
feat_channels=(128, 128),
primitive_feat_refine_streams=2,
primitive_refine_channels=[128, 128, 128],
upper_thresh=100.0,
surface_thresh=0.5,
line_thresh=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=None,
center_loss=None,
dir_class_loss=None,
dir_res_loss=None,
size_class_loss=None,
size_res_loss=None,
semantic_loss=None,
cues_objectness_loss=None,
cues_semantic_loss=None,
proposal_objectness_loss=None,
primitive_center_loss=None,
init_cfg=None):
num_classes: int,
suface_matching_cfg: dict,
line_matching_cfg: dict,
bbox_coder: dict,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
gt_per_seed: int = 1,
num_proposal: int = 256,
primitive_feat_refine_streams: int = 2,
primitive_refine_channels: List[int] = [128, 128, 128],
upper_thresh: float = 100.0,
surface_thresh: float = 0.5,
line_thresh: float = 0.5,
conv_cfg: dict = dict(type='Conv1d'),
norm_cfg: dict = dict(type='BN1d'),
objectness_loss: Optional[dict] = None,
center_loss: Optional[dict] = None,
dir_class_loss: Optional[dict] = None,
dir_res_loss: Optional[dict] = None,
size_class_loss: Optional[dict] = None,
size_res_loss: Optional[dict] = None,
semantic_loss: Optional[dict] = None,
cues_objectness_loss: Optional[dict] = None,
cues_semantic_loss: Optional[dict] = None,
proposal_objectness_loss: Optional[dict] = None,
primitive_center_loss: Optional[dict] = None,
init_cfg: dict = None):
super(H3DBboxHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.train_cfg = train_cfg
......@@ -96,22 +106,22 @@ class H3DBboxHead(BaseModule):
self.surface_thresh = surface_thresh
self.line_thresh = line_thresh
self.objectness_loss = build_loss(objectness_loss)
self.center_loss = build_loss(center_loss)
self.dir_class_loss = build_loss(dir_class_loss)
self.dir_res_loss = build_loss(dir_res_loss)
self.size_class_loss = build_loss(size_class_loss)
self.size_res_loss = build_loss(size_res_loss)
self.semantic_loss = build_loss(semantic_loss)
self.loss_objectness = MODELS.build(objectness_loss)
self.loss_center = MODELS.build(center_loss)
self.loss_dir_class = MODELS.build(dir_class_loss)
self.loss_dir_res = MODELS.build(dir_res_loss)
self.loss_size_class = MODELS.build(size_class_loss)
self.loss_size_res = MODELS.build(size_res_loss)
self.loss_semantic = MODELS.build(semantic_loss)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.bbox_coder = TASK_UTILS.build(bbox_coder)
self.num_sizes = self.bbox_coder.num_sizes
self.num_dir_bins = self.bbox_coder.num_dir_bins
self.cues_objectness_loss = build_loss(cues_objectness_loss)
self.cues_semantic_loss = build_loss(cues_semantic_loss)
self.proposal_objectness_loss = build_loss(proposal_objectness_loss)
self.primitive_center_loss = build_loss(primitive_center_loss)
self.loss_cues_objectness = MODELS.build(cues_objectness_loss)
self.loss_cues_semantic = MODELS.build(cues_semantic_loss)
self.loss_proposal_objectness = MODELS.build(proposal_objectness_loss)
self.loss_primitive_center = MODELS.build(primitive_center_loss)
assert suface_matching_cfg['mlp_channels'][-1] == \
line_matching_cfg['mlp_channels'][-1]
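# The registry-built losses are ordinary modules that keep mmdet's
# weighted-loss call convention, so only the attribute names change for
# callers. A hedged usage sketch (mirrors `get_proposal_stage_loss` later
# in this file):
objectness_loss = self.loss_objectness(
    bbox_preds['obj_scores'].transpose(2, 1),
    objectness_targets,
    weight=objectness_weights)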
......@@ -202,16 +212,14 @@ class H3DBboxHead(BaseModule):
bbox_coder['num_sizes'] * 4 + self.num_classes)
self.bbox_pred.append(nn.Conv1d(prev_channel, conv_out_channel, 1))
def forward(self, feats_dict, sample_mod):
def forward(self, feats_dict: dict):
"""Forward pass.
Args:
feats_dict (dict): Feature dict from backbone.
sample_mod (str): Sample mode for vote aggregation layer.
valid modes are "vote", "seed" and "random".
Returns:
dict: Predictions of vote head.
dict: Predictions of head.
"""
ret_dict = {}
aggregated_points = feats_dict['aggregated_points']
......@@ -236,7 +244,7 @@ class H3DBboxHead(BaseModule):
dim=1)
# Extract the surface and line centers of rpn proposals
rpn_proposals = feats_dict['proposal_list']
rpn_proposals = feats_dict['rpn_proposals']
rpn_proposals_bbox = DepthInstance3DBoxes(
rpn_proposals.reshape(-1, 7).clone(),
box_dim=rpn_proposals.shape[-1],
......@@ -310,36 +318,29 @@ class H3DBboxHead(BaseModule):
ret_dict[key + '_optimized'] = refine_decode_res[key]
return ret_dict
def loss(self,
bbox_preds,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
img_metas=None,
rpn_targets=None,
gt_bboxes_ignore=None):
"""Compute loss.
def loss(
self,
points: List[Tensor],
feats_dict: dict,
rpn_targets: Tuple = None,
batch_data_samples: List[Det3DDataSample] = None,
):
"""
Args:
bbox_preds (dict): Predictions from forward of h3d bbox head.
points (list[torch.Tensor]): Input points.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bboxes of each sample.
gt_labels_3d (list[torch.Tensor]): Labels of each sample.
pts_semantic_mask (list[torch.Tensor]): Point-wise
semantic mask.
pts_instance_mask (list[torch.Tensor]): Point-wise
instance mask.
img_metas (list[dict]): Contain pcd and img's meta info.
rpn_targets (Tuple) : Targets generated by rpn head.
gt_bboxes_ignore (list[torch.Tensor]): Specify
which bounding.
points (list[tensor]): Point clouds of multiple samples.
feats_dict (dict): Predictions from backbone or FPN.
rpn_targets (tuple, optional): Targets generated by the RPN head.
Defaults to None.
batch_data_samples (list[:obj:`Det3DDataSample`], Optional):
Each item contains the meta information of each sample
and corresponding annotations. Defaults to None.
Returns:
dict: Losses of H3dnet.
dict: A dictionary of loss components.
"""
preds = self(feats_dict)
feats_dict.update(preds)
(vote_targets, vote_target_masks, size_class_targets, size_res_targets,
dir_class_targets, dir_res_targets, center_targets, _, mask_targets,
valid_gt_masks, objectness_targets, objectness_weights,
......@@ -349,7 +350,7 @@ class H3DBboxHead(BaseModule):
# calculate refined proposal loss
refined_proposal_loss = self.get_proposal_stage_loss(
bbox_preds,
feats_dict,
size_class_targets,
size_res_targets,
dir_class_targets,
......@@ -364,36 +365,60 @@ class H3DBboxHead(BaseModule):
for key in refined_proposal_loss.keys():
losses[key + '_optimized'] = refined_proposal_loss[key]
batch_gt_instance_3d = []
batch_input_metas = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
batch_gt_instance_3d.append(data_sample.gt_instances_3d)
temp_loss = self.loss_by_feat(points, feats_dict, batch_gt_instance_3d)
losses.update(temp_loss)
return losses
def loss_by_feat(self, points: List[torch.Tensor], feats_dict: dict,
batch_gt_instances_3d: List[InstanceData],
**kwargs) -> dict:
"""Compute loss.
Args:
points (list[torch.Tensor]): Input points.
feats_dict (dict): Predictions from forward of the bbox head.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes`` and ``labels``
attributes.
Returns:
dict: Losses of H3DNet.
"""
bbox3d_optimized = self.bbox_coder.decode(
bbox_preds, suffix='_optimized')
feats_dict, suffix='_optimized')
targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask,
bbox_preds)
targets = self.get_targets(points, feats_dict, batch_gt_instances_3d)
(cues_objectness_label, cues_sem_label, proposal_objectness_label,
cues_mask, cues_match_mask, proposal_objectness_mask,
cues_matching_label, obj_surface_line_center) = targets
# match scores for each geometric primitive
objectness_scores = bbox_preds['matching_score']
objectness_scores = feats_dict['matching_score']
# match scores for the semantics of primitives
objectness_scores_sem = bbox_preds['semantic_matching_score']
objectness_scores_sem = feats_dict['semantic_matching_score']
primitive_objectness_loss = self.cues_objectness_loss(
primitive_objectness_loss = self.loss_cues_objectness(
objectness_scores.transpose(2, 1),
cues_objectness_label,
weight=cues_mask,
avg_factor=cues_mask.sum() + 1e-6)
primitive_sem_loss = self.cues_semantic_loss(
primitive_sem_loss = self.loss_cues_semantic(
objectness_scores_sem.transpose(2, 1),
cues_sem_label,
weight=cues_mask,
avg_factor=cues_mask.sum() + 1e-6)
objectness_scores = bbox_preds['obj_scores_optimized']
objectness_loss_refine = self.proposal_objectness_loss(
objectness_scores = feats_dict['obj_scores_optimized']
objectness_loss_refine = self.loss_proposal_objectness(
objectness_scores.transpose(2, 1), proposal_objectness_label)
primitive_matching_loss = (objectness_loss_refine *
cues_match_mask).sum() / (
......@@ -419,7 +444,7 @@ class H3DBboxHead(BaseModule):
pred_surface_line_center = torch.cat(
(pred_obj_surface_center, pred_obj_line_center), 1)
square_dist = self.primitive_center_loss(pred_surface_line_center,
square_dist = self.loss_primitive_center(pred_surface_line_center,
obj_surface_line_center)
match_dist = torch.sqrt(square_dist.sum(dim=-1) + 1e-6)
......@@ -434,58 +459,102 @@ class H3DBboxHead(BaseModule):
primitive_sem_matching_loss=primitive_sem_matching_loss,
primitive_centroid_reg_loss=primitive_centroid_reg_loss)
losses.update(refined_loss)
return refined_loss
return losses
def predict(self,
points: List[torch.Tensor],
feats_dict: Dict[str, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
suffix='_optimized',
**kwargs) -> List[InstanceData]:
"""
Args:
points (list[tensor]): Point clouds of multiple samples.
feats_dict (dict): Features from FPN or backbone.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes meta information of data.
suffix (str): Suffix for tensors in feats_dict.
Defaults to '_optimized'.
def get_bboxes(self,
points,
bbox_preds,
input_metas,
rescale=False,
suffix=''):
Returns:
list[:obj:`InstanceData`]: List of processed predictions. Each
InstanceData contains 3D bounding boxes and corresponding
scores and labels.
"""
preds_dict = self(feats_dict)
# `preds_dict` can be used in H3DNet
feats_dict.update(preds_dict)
batch_size = len(batch_data_samples)
batch_input_metas = []
for batch_index in range(batch_size):
metainfo = batch_data_samples[batch_index].metainfo
batch_input_metas.append(metainfo)
results_list = self.predict_by_feat(
points, feats_dict, batch_input_metas, suffix=suffix, **kwargs)
return results_list
def predict_by_feat(self,
points: List[torch.Tensor],
feats_dict: dict,
batch_input_metas: List[dict],
suffix='_optimized',
**kwargs) -> List[InstanceData]:
"""Generate bboxes from vote head predictions.
Args:
points (torch.Tensor): Input points.
bbox_preds (dict): Predictions from vote head.
input_metas (list[dict]): Point cloud and image's meta info.
rescale (bool): Whether to rescale bboxes.
points (List[torch.Tensor]): Input points of multiple samples.
feats_dict (dict): Predictions from previous components.
batch_input_metas (list[dict]): Each item
contains the meta information of each sample.
suffix (str): Suffix for tensors in feats_dict.
Defaults to '_optimized'.
Returns:
list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
list[:obj:`InstanceData`]: Return a list of processed
predictions. Each InstanceData contains
3D bounding boxes and corresponding scores and labels.
"""
# decode boxes
obj_scores = F.softmax(
bbox_preds['obj_scores' + suffix], dim=-1)[..., -1]
feats_dict['obj_scores' + suffix], dim=-1)[..., -1]
sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)
sem_scores = F.softmax(feats_dict['sem_scores'], dim=-1)
prediction_collection = {}
prediction_collection['center'] = bbox_preds['center' + suffix]
prediction_collection['dir_class'] = bbox_preds['dir_class']
prediction_collection['dir_res'] = bbox_preds['dir_res' + suffix]
prediction_collection['size_class'] = bbox_preds['size_class']
prediction_collection['size_res'] = bbox_preds['size_res' + suffix]
prediction_collection['center'] = feats_dict['center' + suffix]
prediction_collection['dir_class'] = feats_dict['dir_class']
prediction_collection['dir_res'] = feats_dict['dir_res' + suffix]
prediction_collection['size_class'] = feats_dict['size_class']
prediction_collection['size_res'] = feats_dict['size_res' + suffix]
bbox3d = self.bbox_coder.decode(prediction_collection)
batch_size = bbox3d.shape[0]
results = list()
results_list = list()
points = torch.stack(points)
for b in range(batch_size):
temp_results = InstanceData()
bbox_selected, score_selected, labels = self.multiclass_nms_single(
obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3],
input_metas[b])
bbox = input_metas[b]['box_type_3d'](
batch_input_metas[b])
bbox = batch_input_metas[b]['box_type_3d'](
bbox_selected,
box_dim=bbox_selected.shape[-1],
with_yaw=self.bbox_coder.with_rot)
results.append((bbox, score_selected, labels))
return results
temp_results.bboxes_3d = bbox
temp_results.scores_3d = score_selected
temp_results.labels_3d = labels
results_list.append(temp_results)
return results_list
def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
input_meta):
def multiclass_nms_single(self, obj_scores: Tensor, sem_scores: Tensor,
bbox: Tensor, points: Tensor,
input_meta: dict) -> Tuple:
"""Multi-class nms in single batch.
Args:
......@@ -586,13 +655,13 @@ class H3DBboxHead(BaseModule):
dict: Losses of aggregation module.
"""
# calculate objectness loss
objectness_loss = self.objectness_loss(
objectness_loss = self.loss_objectness(
bbox_preds['obj_scores' + suffix].transpose(2, 1),
objectness_targets,
weight=objectness_weights)
# calculate center loss
source2target_loss, target2source_loss = self.center_loss(
source2target_loss, target2source_loss = self.loss_center(
bbox_preds['center' + suffix],
center_targets,
src_weight=box_loss_weights,
......@@ -600,7 +669,7 @@ class H3DBboxHead(BaseModule):
center_loss = source2target_loss + target2source_loss
# calculate direction class loss
dir_class_loss = self.dir_class_loss(
dir_class_loss = self.loss_dir_class(
bbox_preds['dir_class' + suffix].transpose(2, 1),
dir_class_targets,
weight=box_loss_weights)
......@@ -612,11 +681,11 @@ class H3DBboxHead(BaseModule):
heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
dir_res_norm = (bbox_preds['dir_res_norm' + suffix] *
heading_label_one_hot).sum(dim=-1)
dir_res_loss = self.dir_res_loss(
dir_res_loss = self.loss_dir_res(
dir_res_norm, dir_res_targets, weight=box_loss_weights)
# calculate size class loss
size_class_loss = self.size_class_loss(
size_class_loss = self.loss_size_class(
bbox_preds['size_class' + suffix].transpose(2, 1),
size_class_targets,
weight=box_loss_weights)
......@@ -631,13 +700,13 @@ class H3DBboxHead(BaseModule):
one_hot_size_targets_expand).sum(dim=2)
box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat(
1, 1, 3)
size_res_loss = self.size_res_loss(
size_res_loss = self.loss_size_res(
size_residual_norm,
size_res_targets,
weight=box_loss_weights_expand)
# calculate semantic loss
semantic_loss = self.semantic_loss(
semantic_loss = self.loss_semantic(
bbox_preds['sem_scores' + suffix].transpose(2, 1),
mask_targets,
weight=box_loss_weights)
......@@ -653,91 +722,93 @@ class H3DBboxHead(BaseModule):
return losses
def get_targets(self,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
bbox_preds=None):
"""Generate targets of proposal module.
def get_targets(
self,
points,
feats_dict: Optional[dict] = None,
batch_gt_instances_3d: Optional[List[InstanceData]] = None,
):
"""Generate targets of vote head.
Args:
points (list[torch.Tensor]): Points of each batch.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): Labels of each batch.
pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
label of each batch.
pts_instance_mask (list[torch.Tensor]): Point-wise instance
label of each batch.
bbox_preds (torch.Tensor): Bounding box predictions of vote head.
feats_dict (dict, optional): Predictions of previous
components. Defaults to None.
batch_gt_instances_3d (list[:obj:`InstanceData`], optional):
Batch of gt_instances. It usually includes
``bboxes_3d`` and ``labels_3d`` attributes.
Returns:
tuple[torch.Tensor]: Targets of proposal module.
tuple[torch.Tensor]: Targets of vote head.
"""
# find empty example
valid_gt_masks = list()
gt_num = list()
for index in range(len(gt_labels_3d)):
if len(gt_labels_3d[index]) == 0:
fake_box = gt_bboxes_3d[index].tensor.new_zeros(
1, gt_bboxes_3d[index].tensor.shape[-1])
gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
valid_gt_masks.append(gt_labels_3d[index].new_zeros(1))
batch_gt_labels_3d = [
gt_instances_3d.labels_3d
for gt_instances_3d in batch_gt_instances_3d
]
batch_gt_bboxes_3d = [
gt_instances_3d.bboxes_3d
for gt_instances_3d in batch_gt_instances_3d
]
for index in range(len(batch_gt_labels_3d)):
if len(batch_gt_labels_3d[index]) == 0:
fake_box = batch_gt_bboxes_3d[index].tensor.new_zeros(
1, batch_gt_bboxes_3d[index].tensor.shape[-1])
batch_gt_bboxes_3d[index] = batch_gt_bboxes_3d[index].new_box(
fake_box)
batch_gt_labels_3d[index] = batch_gt_labels_3d[
index].new_zeros(1)
valid_gt_masks.append(batch_gt_labels_3d[index].new_zeros(1))
gt_num.append(1)
else:
valid_gt_masks.append(gt_labels_3d[index].new_ones(
gt_labels_3d[index].shape))
gt_num.append(gt_labels_3d[index].shape[0])
if pts_semantic_mask is None:
pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
pts_instance_mask = [None for i in range(len(gt_labels_3d))]
valid_gt_masks.append(batch_gt_labels_3d[index].new_ones(
batch_gt_labels_3d[index].shape))
gt_num.append(batch_gt_labels_3d[index].shape[0])
aggregated_points = [
bbox_preds['aggregated_points'][i]
for i in range(len(gt_labels_3d))
feats_dict['aggregated_points'][i]
for i in range(len(batch_gt_labels_3d))
]
surface_center_pred = [
bbox_preds['surface_center_pred'][i]
for i in range(len(gt_labels_3d))
feats_dict['surface_center_pred'][i]
for i in range(len(batch_gt_labels_3d))
]
line_center_pred = [
bbox_preds['pred_line_center'][i]
for i in range(len(gt_labels_3d))
feats_dict['pred_line_center'][i]
for i in range(len(batch_gt_labels_3d))
]
surface_center_object = [
bbox_preds['surface_center_object'][i]
for i in range(len(gt_labels_3d))
feats_dict['surface_center_object'][i]
for i in range(len(batch_gt_labels_3d))
]
line_center_object = [
bbox_preds['line_center_object'][i]
for i in range(len(gt_labels_3d))
feats_dict['line_center_object'][i]
for i in range(len(batch_gt_labels_3d))
]
surface_sem_pred = [
bbox_preds['surface_sem_pred'][i]
for i in range(len(gt_labels_3d))
feats_dict['surface_sem_pred'][i]
for i in range(len(batch_gt_labels_3d))
]
line_sem_pred = [
bbox_preds['sem_cls_scores_line'][i]
for i in range(len(gt_labels_3d))
feats_dict['sem_cls_scores_line'][i]
for i in range(len(batch_gt_labels_3d))
]
(cues_objectness_label, cues_sem_label, proposal_objectness_label,
cues_mask, cues_match_mask, proposal_objectness_mask,
cues_matching_label, obj_surface_line_center) = multi_apply(
self.get_targets_single, points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask, aggregated_points,
surface_center_pred, line_center_pred, surface_center_object,
line_center_object, surface_sem_pred, line_sem_pred)
self._get_targets_single, points, batch_gt_bboxes_3d,
batch_gt_labels_3d, aggregated_points, surface_center_pred,
line_center_pred, surface_center_object, line_center_object,
surface_sem_pred, line_sem_pred)
cues_objectness_label = torch.stack(cues_objectness_label)
cues_sem_label = torch.stack(cues_sem_label)
......@@ -753,19 +824,17 @@ class H3DBboxHead(BaseModule):
proposal_objectness_mask, cues_matching_label,
obj_surface_line_center)
def get_targets_single(self,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
aggregated_points=None,
pred_surface_center=None,
pred_line_center=None,
pred_obj_surface_center=None,
pred_obj_line_center=None,
pred_surface_sem=None,
pred_line_sem=None):
def _get_targets_single(self,
points: Tensor,
gt_bboxes_3d: BaseInstance3DBoxes,
gt_labels_3d: Tensor,
aggregated_points: Optional[Tensor] = None,
pred_surface_center: Optional[Tensor] = None,
pred_line_center: Optional[Tensor] = None,
pred_obj_surface_center: Optional[Tensor] = None,
pred_obj_line_center: Optional[Tensor] = None,
pred_surface_sem: Optional[Tensor] = None,
pred_line_sem: Optional[Tensor] = None):
"""Generate targets for primitive cues for single batch.
Args:
......@@ -773,10 +842,6 @@ class H3DBboxHead(BaseModule):
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth
boxes of each batch.
gt_labels_3d (torch.Tensor): Labels of each batch.
pts_semantic_mask (torch.Tensor): Point-wise semantic
label of each batch.
pts_instance_mask (torch.Tensor): Point-wise instance
label of each batch.
aggregated_points (torch.Tensor): Aggregated points from
vote aggregation layer.
pred_surface_center (torch.Tensor): Prediction of surface center.
......@@ -847,12 +912,10 @@ class H3DBboxHead(BaseModule):
euclidean_dist_line = torch.sqrt(dist_line.squeeze(0) + 1e-6)
objectness_label_surface = euclidean_dist_line.new_zeros(
num_proposals * 6, dtype=torch.long)
objectness_mask_surface = euclidean_dist_line.new_zeros(num_proposals *
6)
objectness_label_line = euclidean_dist_line.new_zeros(
num_proposals * 12, dtype=torch.long)
objectness_mask_line = euclidean_dist_line.new_zeros(num_proposals *
12)
objectness_label_surface_sem = euclidean_dist_line.new_zeros(
num_proposals * 6, dtype=torch.long)
objectness_label_line_sem = euclidean_dist_line.new_zeros(
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.core.bbox import bbox3d2result
from typing import Dict, List
from mmengine import InstanceData
from torch import Tensor
from mmdet3d.registry import MODELS
from ...core import Det3DDataSample
from .base_3droi_head import Base3DRoIHead
......@@ -16,17 +21,15 @@ class H3DRoIHead(Base3DRoIHead):
"""
def __init__(self,
primitive_list,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
primitive_list: List[dict],
bbox_head: dict = None,
train_cfg: dict = None,
test_cfg: dict = None,
init_cfg: dict = None):
super(H3DRoIHead, self).__init__(
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg)
# Primitive module
assert len(primitive_list) == 3
......@@ -39,8 +42,14 @@ class H3DRoIHead(Base3DRoIHead):
one."""
pass
def init_bbox_head(self, bbox_head):
"""Initialize box head."""
def init_bbox_head(self, dummy_args, bbox_head):
"""Initialize box head.
Args:
dummy_args (optional): Just to be compatible with
the interface of the base class.
bbox_head (dict): Config for bbox head.
"""
bbox_head['train_cfg'] = self.train_cfg
bbox_head['test_cfg'] = self.test_cfg
self.bbox_head = MODELS.build(bbox_head)
......@@ -49,111 +58,73 @@ class H3DRoIHead(Base3DRoIHead):
"""Initialize assigner and sampler."""
pass
def forward_train(self,
feats_dict,
img_metas,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask,
pts_instance_mask,
gt_bboxes_ignore=None):
def loss(self, points: List[Tensor], feats_dict: dict,
batch_data_samples: List[Det3DDataSample], **kwargs):
"""Training forward function of PartAggregationROIHead.
Args:
feats_dict (dict): Contains features from the first stage.
img_metas (list[dict]): Contain pcd and img's meta info.
points (list[torch.Tensor]): Input points.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bboxes of each sample.
gt_labels_3d (list[torch.Tensor]): Labels of each sample.
pts_semantic_mask (list[torch.Tensor]): Point-wise
semantic mask.
pts_instance_mask (list[torch.Tensor]): Point-wise
instance mask.
gt_bboxes_ignore (list[torch.Tensor]): Specify
which bounding boxes to ignore.
points (list[torch.Tensor]): Point cloud of each sample.
feats_dict (dict): Dict of feature.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`.
Returns:
dict: losses from each head.
"""
losses = dict()
sample_mod = self.train_cfg.sample_mod
assert sample_mod in ['vote', 'seed', 'random']
result_z = self.primitive_z(feats_dict, sample_mod)
feats_dict.update(result_z)
result_xy = self.primitive_xy(feats_dict, sample_mod)
feats_dict.update(result_xy)
result_line = self.primitive_line(feats_dict, sample_mod)
feats_dict.update(result_line)
primitive_loss_inputs = (feats_dict, points, gt_bboxes_3d,
gt_labels_3d, pts_semantic_mask,
pts_instance_mask, img_metas,
gt_bboxes_ignore)
primitive_loss_inputs = (points, feats_dict, batch_data_samples)
# note: new keys and values will be added into feats_dict by each head.
loss_z = self.primitive_z.loss(*primitive_loss_inputs)
losses.update(loss_z)
loss_xy = self.primitive_xy.loss(*primitive_loss_inputs)
losses.update(loss_xy)
loss_line = self.primitive_line.loss(*primitive_loss_inputs)
losses.update(loss_z)
losses.update(loss_xy)
losses.update(loss_line)
targets = feats_dict.pop('targets')
bbox_results = self.bbox_head(feats_dict, sample_mod)
feats_dict.update(bbox_results)
bbox_loss = self.bbox_head.loss(feats_dict, points, gt_bboxes_3d,
gt_labels_3d, pts_semantic_mask,
pts_instance_mask, img_metas, targets,
gt_bboxes_ignore)
bbox_loss = self.bbox_head.loss(
points,
feats_dict,
rpn_targets=targets,
batch_data_samples=batch_data_samples)
losses.update(bbox_loss)
return losses
def simple_test(self, feats_dict, img_metas, points, rescale=False):
"""Simple testing forward function of PartAggregationROIHead.
Note:
This function assumes that the batch size is 1
def predict(self,
points: List[Tensor],
feats_dict: Dict[str, Tensor],
batch_data_samples: List[Det3DDataSample],
suffix='_optimized',
**kwargs) -> List[InstanceData]:
"""
Args:
feats_dict (dict): Contains features from the first stage.
img_metas (list[dict]): Contain pcd and img's meta info.
points (torch.Tensor): Input points.
rescale (bool): Whether to rescale results.
points (list[tensor]): Point clouds of multiple samples.
feats_dict (dict): Features from FPN or backbone.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes meta information of data.
Returns:
dict: Bbox results of one frame.
list[:obj:`InstanceData`]: List of processed predictions. Each
InstanceData contains 3D bounding boxes and corresponding
scores and labels.
"""
sample_mod = self.test_cfg.sample_mod
assert sample_mod in ['vote', 'seed', 'random']
result_z = self.primitive_z(feats_dict, sample_mod)
result_z = self.primitive_z(feats_dict)
feats_dict.update(result_z)
result_xy = self.primitive_xy(feats_dict, sample_mod)
result_xy = self.primitive_xy(feats_dict)
feats_dict.update(result_xy)
result_line = self.primitive_line(feats_dict, sample_mod)
result_line = self.primitive_line(feats_dict)
feats_dict.update(result_line)
bbox_preds = self.bbox_head(feats_dict, sample_mod)
bbox_preds = self.bbox_head(feats_dict)
feats_dict.update(bbox_preds)
bbox_list = self.bbox_head.get_bboxes(
points,
feats_dict,
img_metas,
rescale=rescale,
suffix='_optimized')
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results
results_list = self.bbox_head.predict(
points, feats_dict, batch_data_samples, suffix=suffix)
return results_list
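# A hedged sketch of the suffix mechanism: assuming the bbox head stores its
# refined outputs under keys ending in '_optimized' (key names illustrative),
# `suffix` simply selects which stage's predictions get decoded.
import torch

stage_feats = {
    'center': torch.zeros(1, 256, 3),           # proposal-stage centers
    'center_optimized': torch.ones(1, 256, 3),  # refined by the second stage
}
suffix = '_optimized'
decoded_center = stage_feats['center' + suffix]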
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional
import torch
from mmcv.cnn import ConvModule
from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule
from mmengine import InstanceData
from torch import nn as nn
from torch.nn import functional as F
from mmdet3d.models.builder import build_loss
from mmdet3d.core import Det3DDataSample
from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import build_sa_module
from mmdet3d.registry import MODELS
......@@ -40,24 +43,25 @@ class PrimitiveHead(BaseModule):
"""
def __init__(self,
num_dims,
num_classes,
primitive_mode,
train_cfg=None,
test_cfg=None,
vote_module_cfg=None,
vote_aggregation_cfg=None,
feat_channels=(128, 128),
upper_thresh=100.0,
surface_thresh=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=None,
center_loss=None,
semantic_reg_loss=None,
semantic_cls_loss=None,
init_cfg=None):
num_dims: int,
num_classes: int,
primitive_mode: str,
train_cfg: dict = None,
test_cfg: dict = None,
vote_module_cfg: dict = None,
vote_aggregation_cfg: dict = None,
feat_channels: tuple = (128, 128),
upper_thresh: float = 100.0,
surface_thresh: float = 0.5,
conv_cfg: dict = dict(type='Conv1d'),
norm_cfg: dict = dict(type='BN1d'),
objectness_loss: dict = None,
center_loss: dict = None,
semantic_reg_loss: dict = None,
semantic_cls_loss: dict = None,
init_cfg: dict = None):
super(PrimitiveHead, self).__init__(init_cfg=init_cfg)
# bounding boxes centers, face centers and edge centers
assert primitive_mode in ['z', 'xy', 'line']
# The dimension of primitive semantic information.
self.num_dims = num_dims
......@@ -70,10 +74,10 @@ class PrimitiveHead(BaseModule):
self.upper_thresh = upper_thresh
self.surface_thresh = surface_thresh
self.objectness_loss = build_loss(objectness_loss)
self.center_loss = build_loss(center_loss)
self.semantic_reg_loss = build_loss(semantic_reg_loss)
self.semantic_cls_loss = build_loss(semantic_cls_loss)
self.loss_objectness = MODELS.build(objectness_loss)
self.loss_center = MODELS.build(center_loss)
self.loss_semantic_reg = MODELS.build(semantic_reg_loss)
self.loss_semantic_cls = MODELS.build(semantic_cls_loss)
assert vote_aggregation_cfg['mlp_channels'][0] == vote_module_cfg[
'in_channels']
......@@ -114,18 +118,26 @@ class PrimitiveHead(BaseModule):
self.conv_pred.add_module('conv_out',
nn.Conv1d(prev_channel, conv_out_channel, 1))
def forward(self, feats_dict, sample_mod):
@property
def sample_mode(self):
if self.training:
sample_mode = self.train_cfg.sample_mode
else:
sample_mode = self.test_cfg.sample_mode
assert sample_mode in ['vote', 'seed', 'random']
return sample_mode
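# Toy illustration (cfg objects assumed) of how such a property follows the
# module's train/eval state, so callers no longer pass sample_mod explicitly.
import torch
from types import SimpleNamespace

class ToyHead(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.train_cfg = SimpleNamespace(sample_mode='vote')
        self.test_cfg = SimpleNamespace(sample_mode='seed')

    @property
    def sample_mode(self):
        cfg = self.train_cfg if self.training else self.test_cfg
        return cfg.sample_mode

toy_head = ToyHead()
toy_head.train()
assert toy_head.sample_mode == 'vote'
toy_head.eval()
assert toy_head.sample_mode == 'seed'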
def forward(self, feats_dict):
"""Forward pass.
Args:
feats_dict (dict): Feature dict from backbone.
sample_mod (str): Sample mode for vote aggregation layer.
valid modes are "vote", "seed" and "random".
Returns:
dict: Predictions of primitive head.
"""
assert sample_mod in ['vote', 'seed', 'random']
sample_mode = self.sample_mode
seed_points = feats_dict['fp_xyz_net0'][-1]
seed_features = feats_dict['hd_feature']
......@@ -143,14 +155,14 @@ class PrimitiveHead(BaseModule):
results['vote_features_' + self.primitive_mode] = vote_features
# 2. aggregate vote_points
if sample_mod == 'vote':
if sample_mode == 'vote':
# use fps in vote_aggregation
sample_indices = None
elif sample_mod == 'seed':
elif sample_mode == 'seed':
# FPS on seed and choose the votes corresponding to the seeds
sample_indices = furthest_point_sample(seed_points,
self.num_proposal)
elif sample_mod == 'random':
elif sample_mode == 'random':
# Random sampling from the votes
batch_size, num_seed = seed_points.shape[:2]
sample_indices = torch.randint(
......@@ -185,63 +197,103 @@ class PrimitiveHead(BaseModule):
results['pred_' + self.primitive_mode + '_center'] = center
return results
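# Standalone sketch of the three sampling strategies (shapes assumed;
# furthest_point_sample is a CUDA op, hence the availability guard).
import torch
from mmcv.ops import furthest_point_sample

batch_size, num_seed, num_proposal = 2, 1024, 256
demo_seeds = torch.rand(batch_size, num_seed, 3)

sample_indices = None  # 'vote': FPS runs inside the aggregation module
if torch.cuda.is_available():
    # 'seed': FPS on the seed points, reuse those indices for the votes
    sample_indices = furthest_point_sample(demo_seeds.cuda(), num_proposal)
# 'random': uniform sampling over the votes
sample_indices = torch.randint(0, num_seed, (batch_size, num_proposal))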
def loss(self,
bbox_preds,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
img_metas=None,
gt_bboxes_ignore=None):
def loss(self, points: List[torch.Tensor],
         feats_dict: Dict[str, torch.Tensor],
         batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
"""
Args:
points (list[tensor]): Point clouds of multiple samples.
feats_dict (dict): Predictions from backbone or FPN.
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each sample and
corresponding annotations.
Returns:
dict: A dictionary of loss components.
"""
preds = self(feats_dict)
feats_dict.update(preds)
batch_gt_instance_3d = []
batch_gt_instances_ignore = []
batch_input_metas = []
batch_pts_semantic_mask = []
batch_pts_instance_mask = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
batch_gt_instance_3d.append(data_sample.gt_instances_3d)
batch_gt_instances_ignore.append(
data_sample.get('ignored_instances', None))
batch_pts_semantic_mask.append(
data_sample.gt_pts_seg.get('pts_semantic_mask', None))
batch_pts_instance_mask.append(
data_sample.gt_pts_seg.get('pts_instance_mask', None))
loss_inputs = (points, feats_dict, batch_gt_instance_3d)
losses = self.loss_by_feat(
*loss_inputs,
batch_pts_semantic_mask=batch_pts_semantic_mask,
batch_pts_instance_mask=batch_pts_instance_mask,
batch_gt_instances_ignore=batch_gt_instances_ignore,
)
return losses
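# A minimal sketch of the batch_data_samples expected here, mirroring the
# test helper later in this diff (sizes and class count are illustrative).
import torch
from mmengine import InstanceData
from mmdet3d.core import Det3DDataSample, DepthInstance3DBoxes, PointData

gt = InstanceData()
gt.bboxes_3d = DepthInstance3DBoxes(torch.rand(5, 7))
gt.labels_3d = torch.randint(0, 18, (5, ))
demo_sample = Det3DDataSample(metainfo=dict(box_type_3d=DepthInstance3DBoxes))
demo_sample.gt_instances_3d = gt
demo_sample.gt_pts_seg = PointData()
batch_data_samples = [demo_sample]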
def loss_by_feat(
self,
points: List[torch.Tensor],
feats_dict: dict,
batch_gt_instances_3d: List[InstanceData],
batch_pts_semantic_mask: Optional[List[torch.Tensor]] = None,
batch_pts_instance_mask: Optional[List[torch.Tensor]] = None,
**kwargs):
"""Compute loss.
Args:
bbox_preds (dict): Predictions from forward of primitive head.
points (list[torch.Tensor]): Input points.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bboxes of each sample.
gt_labels_3d (list[torch.Tensor]): Labels of each sample.
pts_semantic_mask (list[torch.Tensor]): Point-wise
semantic mask.
pts_instance_mask (list[torch.Tensor]): Point-wise
instance mask.
img_metas (list[dict]): Contain pcd and img's meta info.
gt_bboxes_ignore (list[torch.Tensor]): Specify
which bounding boxes to ignore.
feats_dict (dict): Predictions of previous modules.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes`` and ``labels``
attributes.
batch_pts_semantic_mask (list[tensor]): Semantic mask
of point clouds. Defaults to None.
batch_pts_instance_mask (list[tensor]): Instance mask
of point clouds. Defaults to None.
Returns:
dict: Losses of Primitive Head.
"""
targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask,
bbox_preds)
targets = self.get_targets(points, feats_dict, batch_gt_instances_3d,
batch_pts_semantic_mask,
batch_pts_instance_mask)
(point_mask, point_offset, gt_primitive_center, gt_primitive_semantic,
gt_sem_cls_label, gt_primitive_mask) = targets
losses = {}
# Compute the loss of primitive existence flag
pred_flag = bbox_preds['pred_flag_' + self.primitive_mode]
flag_loss = self.objectness_loss(pred_flag, gt_primitive_mask.long())
pred_flag = feats_dict['pred_flag_' + self.primitive_mode]
flag_loss = self.loss_objectness(pred_flag, gt_primitive_mask.long())
losses['flag_loss_' + self.primitive_mode] = flag_loss
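# Hedged sketch of this flag term in isolation: a two-class cross-entropy
# over per-point existence logits with per-class weights and a global scale
# (all values below are illustrative, not the actual config).
import torch
import torch.nn.functional as F

demo_pred_flag = torch.randn(2, 2, 1024)       # (B, 2, N) existence logits
demo_gt_mask = torch.randint(0, 2, (2, 1024))  # 1 where a primitive exists
demo_flag_loss = 10.0 * F.cross_entropy(
    demo_pred_flag, demo_gt_mask, weight=torch.tensor([0.3, 0.7]))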
# calculate vote loss
vote_loss = self.vote_module.get_loss(
bbox_preds['seed_points'],
bbox_preds['vote_' + self.primitive_mode],
bbox_preds['seed_indices'], point_mask, point_offset)
feats_dict['seed_points'],
feats_dict['vote_' + self.primitive_mode],
feats_dict['seed_indices'], point_mask, point_offset)
losses['vote_loss_' + self.primitive_mode] = vote_loss
num_proposal = bbox_preds['aggregated_points_' +
num_proposal = feats_dict['aggregated_points_' +
self.primitive_mode].shape[1]
primitive_center = bbox_preds['center_' + self.primitive_mode]
primitive_center = feats_dict['center_' + self.primitive_mode]
if self.primitive_mode != 'line':
primitive_semantic = bbox_preds['size_residuals_' +
primitive_semantic = feats_dict['size_residuals_' +
self.primitive_mode].contiguous()
else:
primitive_semantic = None
semantic_scores = bbox_preds['sem_cls_scores_' +
semantic_scores = feats_dict['sem_cls_scores_' +
self.primitive_mode].transpose(2, 1)
gt_primitive_mask = gt_primitive_mask / \
......@@ -256,44 +308,61 @@ class PrimitiveHead(BaseModule):
return losses
def get_targets(self,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
bbox_preds=None):
def get_targets(
self,
points,
bbox_preds: Optional[dict] = None,
batch_gt_instances_3d: List[InstanceData] = None,
batch_pts_semantic_mask: List[torch.Tensor] = None,
batch_pts_instance_mask: List[torch.Tensor] = None,
):
"""Generate targets of primitive head.
Args:
points (list[torch.Tensor]): Points of each batch.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): Labels of each batch.
pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
label of each batch.
pts_instance_mask (list[torch.Tensor]): Point-wise instance
label of each batch.
bbox_preds (dict): Predictions from forward of primitive head.
bbox_preds (dict, optional): Predictions from forward of
primitive head. Defaults to None.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
batch_pts_semantic_mask (list[tensor]): Semantic gt mask for
multiple point cloud samples.
batch_pts_instance_mask (list[tensor]): Instance gt mask for
multiple point cloud samples.
Returns:
tuple[torch.Tensor]: Targets of primitive head.
"""
for index in range(len(gt_labels_3d)):
if len(gt_labels_3d[index]) == 0:
fake_box = gt_bboxes_3d[index].tensor.new_zeros(
1, gt_bboxes_3d[index].tensor.shape[-1])
gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
if pts_semantic_mask is None:
pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
pts_instance_mask = [None for i in range(len(gt_labels_3d))]
batch_gt_labels_3d = [
gt_instances_3d.labels_3d
for gt_instances_3d in batch_gt_instances_3d
]
batch_gt_bboxes_3d = [
gt_instances_3d.bboxes_3d
for gt_instances_3d in batch_gt_instances_3d
]
for index in range(len(batch_gt_labels_3d)):
if len(batch_gt_labels_3d[index]) == 0:
fake_box = batch_gt_bboxes_3d[index].tensor.new_zeros(
1, batch_gt_bboxes_3d[index].tensor.shape[-1])
batch_gt_bboxes_3d[index] = batch_gt_bboxes_3d[index].new_box(
fake_box)
batch_gt_labels_3d[index] = batch_gt_labels_3d[
index].new_zeros(1)
if batch_pts_semantic_mask is None:
batch_pts_semantic_mask = [
None for _ in range(len(batch_gt_labels_3d))
]
batch_pts_instance_mask = [
None for _ in range(len(batch_gt_labels_3d))
]
(point_mask, point_sem,
point_offset) = multi_apply(self.get_targets_single, points,
gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask)
batch_gt_bboxes_3d, batch_gt_labels_3d,
batch_pts_semantic_mask,
batch_pts_instance_mask)
point_mask = torch.stack(point_mask)
point_sem = torch.stack(point_sem)
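# multi_apply maps a per-sample function over zipped per-sample arguments and
# transposes the per-sample result tuples into per-field lists; a minimal
# sketch of what mmdet's multi_apply does:
from functools import partial

def multi_apply_sketch(func, *args, **kwargs):
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    # [(mask1, sem1, off1), (mask2, sem2, off2)]
    # -> ([mask1, mask2], [sem1, sem2], [off1, off2])
    return tuple(map(list, zip(*map_results)))

masks, sems, offsets = multi_apply_sketch(
    lambda pts, label: (pts, pts * 2, pts - 1),  # stand-in for get_targets_single
    [1.0, 2.0], [0, 1])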
......@@ -759,7 +828,7 @@ class PrimitiveHead(BaseModule):
vote_xyz_reshape = primitive_center.view(batch_size * num_proposal, -1,
3)
center_loss = self.center_loss(
center_loss = self.loss_center(
vote_xyz_reshape,
gt_primitive_center,
dst_weight=gt_primitive_mask.view(batch_size * num_proposal, 1))[1]
......@@ -767,7 +836,7 @@ class PrimitiveHead(BaseModule):
if self.primitive_mode != 'line':
size_xyz_reshape = primitive_semantic.view(
batch_size * num_proposal, -1, self.num_dims).contiguous()
size_loss = self.semantic_reg_loss(
size_loss = self.loss_semantic_reg(
size_xyz_reshape,
gt_primitive_semantic,
dst_weight=gt_primitive_mask.view(batch_size * num_proposal,
......@@ -776,7 +845,7 @@ class PrimitiveHead(BaseModule):
size_loss = center_loss.new_tensor(0.0)
# Semantic cls loss
sem_cls_loss = self.semantic_cls_loss(
sem_cls_loss = self.loss_semantic_cls(
semantic_scores, gt_sem_cls_label, weight=gt_primitive_mask)
return center_loss, size_loss, sem_cls_loss
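# Sketch of the masked semantic classification term above, with assumed
# shapes: per-proposal cross-entropy, weighted by the primitive mask, summed.
import torch
import torch.nn.functional as F

B, C, P = 2, 18, 64
demo_scores = torch.randn(B, C, P)         # class logits per proposal
demo_labels = torch.randint(0, C, (B, P))  # per-proposal class targets
demo_mask = torch.rand(B, P)               # weight for matched proposals
per_proposal = F.cross_entropy(demo_scores, demo_labels, reduction='none')
demo_sem_cls_loss = (per_proposal * demo_mask).sum()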
......
import unittest
import torch
from mmengine import DefaultScope
from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
_get_detector_cfg, _setup_seed)
class TestH3D(unittest.TestCase):
def test_h3dnet(self):
import mmdet3d.models
assert hasattr(mmdet3d.models, 'H3DNet')
DefaultScope.get_instance('test_H3DNet', scope_name='mmdet3d')
_setup_seed(0)
h3dnet_cfg = _get_detector_cfg(
'h3dnet/h3dnet_3x8_scannet-3d-18class.py')
model = MODELS.build(h3dnet_cfg)
num_gt_instance = 5
data = [
_create_detector_inputs(
num_gt_instance=num_gt_instance,
points_feat_dim=4,
bboxes_3d_type='depth',
with_pts_semantic_mask=True,
with_pts_instance_mask=True)
]
if torch.cuda.is_available():
model = model.cuda()
# test predict
with torch.no_grad():
batch_inputs, data_samples = model.data_preprocessor(
data, True)
results = model.forward(
batch_inputs, data_samples, mode='predict')
self.assertEqual(len(results), len(data))
self.assertIn('bboxes_3d', results[0].pred_instances_3d)
self.assertIn('scores_3d', results[0].pred_instances_3d)
self.assertIn('labels_3d', results[0].pred_instances_3d)
# compute losses under no_grad to save memory
with torch.no_grad():
losses = model.forward(batch_inputs, data_samples, mode='loss')
self.assertGreater(losses['vote_loss'], 0)
self.assertGreater(losses['objectness_loss'], 0)
self.assertGreater(losses['center_loss'], 0)
......@@ -7,7 +7,8 @@ import numpy as np
import torch
from mmengine import InstanceData
from mmdet3d.core import Det3DDataSample, LiDARInstance3DBoxes, PointData
from mmdet3d.core import (CameraInstance3DBoxes, DepthInstance3DBoxes,
Det3DDataSample, LiDARInstance3DBoxes, PointData)
def _setup_seed(seed):
......@@ -71,19 +72,24 @@ def _get_detector_cfg(fname):
return model
def _create_detector_inputs(
seed=0,
with_points=True,
with_img=False,
num_gt_instance=20,
num_points=10,
points_feat_dim=4,
num_classes=3,
gt_bboxes_dim=7,
with_pts_semantic_mask=False,
with_pts_instance_mask=False,
):
def _create_detector_inputs(seed=0,
with_points=True,
with_img=False,
num_gt_instance=20,
num_points=10,
points_feat_dim=4,
num_classes=3,
gt_bboxes_dim=7,
with_pts_semantic_mask=False,
with_pts_instance_mask=False,
bboxes_3d_type='lidar'):
_setup_seed(seed)
assert bboxes_3d_type in ('lidar', 'depth', 'cam')
bbox_3d_class = {
'lidar': LiDARInstance3DBoxes,
'depth': DepthInstance3DBoxes,
'cam': CameraInstance3DBoxes
}
if with_points:
points = torch.rand([num_points, points_feat_dim])
else:
......@@ -93,12 +99,13 @@ def _create_detector_inputs(
else:
img = None
inputs_dict = dict(img=img, points=points)
gt_instance_3d = InstanceData()
gt_instance_3d.bboxes_3d = LiDARInstance3DBoxes(
gt_instance_3d.bboxes_3d = bbox_3d_class[bboxes_3d_type](
torch.rand([num_gt_instance, gt_bboxes_dim]), box_dim=gt_bboxes_dim)
gt_instance_3d.labels_3d = torch.randint(0, num_classes, [num_gt_instance])
data_sample = Det3DDataSample(
metainfo=dict(box_type_3d=LiDARInstance3DBoxes))
metainfo=dict(box_type_3d=bbox_3d_class[bboxes_3d_type]))
data_sample.gt_instances_3d = gt_instance_3d
data_sample.gt_pts_seg = PointData()
if with_pts_instance_mask:
......