Commit 0e17beab authored by jshilong, committed by ChaimZhu

[Refactor] Refactor H3D

parent 9ebb75da
......@@ -30,7 +30,7 @@ primitive_z_cfg = dict(
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.4, 0.6],
reduction='mean',
loss_weight=30.0),
......@@ -47,14 +47,16 @@ primitive_z_cfg = dict(
loss_src_weight=0.5,
loss_dst_weight=0.5),
semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
train_cfg=dict(
sample_mode='vote',
dist_thresh=0.2,
var_thresh=1e-2,
lower_thresh=1e-6,
num_point=100,
num_point_line=10,
line_thresh=0.2))
line_thresh=0.2),
test_cfg=dict(sample_mode='seed'))
primitive_xy_cfg = dict(
type='PrimitiveHead',
......@@ -88,7 +90,7 @@ primitive_xy_cfg = dict(
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.4, 0.6],
reduction='mean',
loss_weight=30.0),
......@@ -105,14 +107,16 @@ primitive_xy_cfg = dict(
loss_src_weight=0.5,
loss_dst_weight=0.5),
semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
train_cfg=dict(
sample_mode='vote',
dist_thresh=0.2,
var_thresh=1e-2,
lower_thresh=1e-6,
num_point=100,
num_point_line=10,
line_thresh=0.2))
line_thresh=0.2),
test_cfg=dict(sample_mode='seed'))
primitive_line_cfg = dict(
type='PrimitiveHead',
......@@ -146,7 +150,7 @@ primitive_line_cfg = dict(
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.4, 0.6],
reduction='mean',
loss_weight=30.0),
......@@ -163,17 +167,20 @@ primitive_line_cfg = dict(
loss_src_weight=1.0,
loss_dst_weight=1.0),
semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=2.0),
type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=2.0),
train_cfg=dict(
sample_mode='vote',
dist_thresh=0.2,
var_thresh=1e-2,
lower_thresh=1e-6,
num_point=100,
num_point_line=10,
line_thresh=0.2))
line_thresh=0.2),
test_cfg=dict(sample_mode='seed'))
model = dict(
type='H3DNet',
data_preprocessor=dict(type='Det3DDataPreprocessor'),
backbone=dict(
type='MultiBackbone',
num_streams=4,
......@@ -221,10 +228,8 @@ model = dict(
normalize_xyz=True),
pred_layer_cfg=dict(
in_channels=128, shared_conv_channels=(128, 128), bias=True),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='sum',
loss_weight=5.0),
......@@ -235,15 +240,15 @@ model = dict(
loss_src_weight=10.0,
loss_dst_weight=10.0),
dir_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
dir_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
size_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
size_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
semantic_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)),
type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0)),
roi_head=dict(
type='H3DRoIHead',
primitive_list=[primitive_z_cfg, primitive_xy_cfg, primitive_line_cfg],
......@@ -267,7 +272,6 @@ model = dict(
mlp_channels=[128 + 12, 128, 64, 32],
use_xyz=True,
normalize_xyz=True),
feat_channels=(128, 128),
primitive_refine_channels=[128, 128, 128],
upper_thresh=100.0,
surface_thresh=0.5,
......@@ -275,7 +279,7 @@ model = dict(
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='sum',
loss_weight=5.0),
......@@ -286,41 +290,47 @@ model = dict(
loss_src_weight=10.0,
loss_dst_weight=10.0),
dir_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
type='mmdet.CrossEntropyLoss',
reduction='sum',
loss_weight=0.1),
dir_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
size_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
type='mmdet.CrossEntropyLoss',
reduction='sum',
loss_weight=0.1),
size_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
semantic_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
type='mmdet.CrossEntropyLoss',
reduction='sum',
loss_weight=0.1),
cues_objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
cues_semantic_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
proposal_objectness_loss=dict(
type='CrossEntropyLoss',
type='mmdet.CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='none',
loss_weight=5.0),
primitive_center_loss=dict(
type='MSELoss', reduction='none', loss_weight=1.0))),
type='mmdet.MSELoss', reduction='none', loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'),
pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mode='vote'),
rpn_proposal=dict(use_nms=False),
rcnn=dict(
pos_distance_thr=0.3,
neg_distance_thr=0.6,
sample_mod='vote',
sample_mode='vote',
far_threshold=0.6,
near_threshold=0.3,
mask_surface_threshold=0.3,
......@@ -329,13 +339,13 @@ model = dict(
label_line_threshold=0.3)),
test_cfg=dict(
rpn=dict(
sample_mod='seed',
sample_mode='seed',
nms_thr=0.25,
score_thr=0.05,
per_class_proposal=True,
use_nms=False),
rcnn=dict(
sample_mod='seed',
sample_mode='seed',
nms_thr=0.25,
score_thr=0.05,
per_class_proposal=True)))
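The pattern running through this config refactor is consistent: every loss `type` gains the `mmdet.` scope prefix so that mmengine's registry resolves the class from mmdet rather than mmdet3d, `sample_mod` is renamed to `sample_mode`, and each primitive head now carries an explicit `test_cfg`. A minimal sketch of the scope-prefixed lookup, assuming the mmdet3d 1.x registries imported elsewhere in this commit:

```python
# Minimal sketch (assumption: mmengine-style scoped registries, as used
# elsewhere in this commit via `from mmdet3d.registry import MODELS`).
from mmdet3d.registry import MODELS

loss_cfg = dict(
    type='mmdet.CrossEntropyLoss',  # 'mmdet.' scopes the lookup to mmdet's registry
    reduction='sum',
    loss_weight=1.0)
loss = MODELS.build(loss_cfg)  # builds mmdet's CrossEntropyLoss, not an mmdet3d class
```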
_base_ = [
'../_base_/datasets/scannet-3d-18class.py', '../_base_/models/h3dnet.py',
'../_base_/schedules/schedule_3x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
rpn_head=dict(
num_classes=18,
bbox_coder=dict(
type='PartialBinBasedBBoxCoder',
num_sizes=18,
num_dir_bins=24,
with_rot=False,
mean_sizes=[[0.76966727, 0.8116021, 0.92573744],
[1.876858, 1.8425595, 1.1931566],
[0.61328, 0.6148609, 0.7182701],
[1.3955007, 1.5121545, 0.83443564],
[0.97949594, 1.0675149, 0.6329687],
[0.531663, 0.5955577, 1.7500148],
[0.9624706, 0.72462326, 1.1481868],
[0.83221924, 1.0490936, 1.6875663],
[0.21132214, 0.4206159, 0.5372846],
[1.4440073, 1.8970833, 0.26985747],
[1.0294262, 1.4040797, 0.87554324],
[1.3766412, 0.65521795, 1.6813129],
[0.6650819, 0.71111923, 1.298853],
[0.41999173, 0.37906948, 1.7513971],
[0.59359556, 0.5912492, 0.73919016],
[0.50867593, 0.50656086, 0.30136237],
[1.1511526, 1.0546296, 0.49706793],
[0.47535285, 0.49249494, 0.5802117]])),
roi_head=dict(
bbox_head=dict(
num_classes=18,
bbox_coder=dict(
type='PartialBinBasedBBoxCoder',
num_sizes=18,
num_dir_bins=24,
with_rot=False,
mean_sizes=[[0.76966727, 0.8116021, 0.92573744],
[1.876858, 1.8425595, 1.1931566],
[0.61328, 0.6148609, 0.7182701],
[1.3955007, 1.5121545, 0.83443564],
[0.97949594, 1.0675149, 0.6329687],
[0.531663, 0.5955577, 1.7500148],
[0.9624706, 0.72462326, 1.1481868],
[0.83221924, 1.0490936, 1.6875663],
[0.21132214, 0.4206159, 0.5372846],
[1.4440073, 1.8970833, 0.26985747],
[1.0294262, 1.4040797, 0.87554324],
[1.3766412, 0.65521795, 1.6813129],
[0.6650819, 0.71111923, 1.298853],
[0.41999173, 0.37906948, 1.7513971],
[0.59359556, 0.5912492, 0.73919016],
[0.50867593, 0.50656086, 0.30136237],
[1.1511526, 1.0546296, 0.49706793],
[0.47535285, 0.49249494, 0.5802117]]))))
train_dataloader = dict(
batch_size=3,
num_workers=2,
)
# yapf:disable
default_hooks = dict(
logger=dict(type='LoggerHook', interval=30)
)
# yapf:enable
......@@ -57,8 +57,13 @@ model = dict(
[1.1511526, 1.0546296, 0.49706793],
[0.47535285, 0.49249494, 0.5802117]]))))
data = dict(samples_per_gpu=3, workers_per_gpu=2)
train_dataloader = dict(
batch_size=3,
num_workers=2,
)
# yapf:disable
log_config = dict(interval=30)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=30)
)
# yapf:enable
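Both runtime blocks above show the same mmengine migration: loader settings move from the monolithic `data` dict to `train_dataloader`, and `log_config` becomes a `LoggerHook` entry under `default_hooks`. A hedged sketch of the mapping (values from the diff, semantics assumed from mmengine conventions):

```python
# was: data = dict(samples_per_gpu=3, workers_per_gpu=2)
train_dataloader = dict(batch_size=3, num_workers=2)
# was: log_config = dict(interval=30)
default_hooks = dict(logger=dict(type='LoggerHook', interval=30))
```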
......@@ -5,11 +5,9 @@ from typing import Callable, List, Optional, Union
import numpy as np
from mmdet3d.core import show_result
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.registry import DATASETS
from .det3d_dataset import Det3DDataset
from .pipelines import Compose
from .seg3d_dataset import Seg3DDataset
......@@ -151,46 +149,6 @@ class ScanNetDataset(Det3DDataset):
return ann_info
def _build_default_pipeline(self):
"""Build the default pipeline for this dataset."""
pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='DefaultFormatBundle3D',
class_names=self.CLASSES,
with_label=False),
dict(type='Collect3D', keys=['points'])
]
return Compose(pipeline)
def show(self, results, out_dir, show=True, pipeline=None):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
show (bool): Visualize the results online.
pipeline (list[dict], optional): raw data loading for showing.
Default: None.
"""
assert out_dir is not None, 'Expect out_dir, got none.'
pipeline = self._get_pipeline(pipeline)
for i, result in enumerate(results):
data_info = self.get_data_info[i]
pts_path = data_info['lidar_points']['lidar_path']
file_name = osp.split(pts_path)[-1].split('.')[0]
points = self._extract_data(i, pipeline, 'points').numpy()
gt_bboxes = self.get_ann_info(i)['gt_bboxes_3d'].tensor.numpy()
pred_bboxes = result['boxes_3d'].tensor.numpy()
show_result(points, gt_bboxes, pred_bboxes, out_dir, file_name,
show)
@DATASETS.register_module()
class ScanNetSegDataset(Seg3DDataset):
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule
from mmengine import ConfigDict, InstanceData
from torch import Tensor
from torch.nn import functional as F
from mmdet3d.core.post_processing import aligned_3d_nms
......@@ -161,7 +162,7 @@ class VoteHead(BaseModule):
points: List[torch.Tensor],
feats_dict: Dict[str, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
rescale=True,
use_nms: bool = True,
**kwargs) -> List[InstanceData]:
"""
Args:
......@@ -169,8 +170,8 @@ class VoteHead(BaseModule):
feats_dict (dict): Features from FPN or backbone.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes meta information of data.
rescale (bool): Whether rescale the resutls to
the original scale.
use_nms (bool): Whether to do NMS for the predictions.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: List of processed predictions. Each
......@@ -178,6 +179,9 @@ class VoteHead(BaseModule):
scores and labels.
"""
preds_dict = self(feats_dict)
# `preds_dict` can be used in H3DNet
feats_dict.update(preds_dict)
batch_size = len(batch_data_samples)
batch_input_metas = []
for batch_index in range(batch_size):
......@@ -185,12 +189,73 @@ class VoteHead(BaseModule):
batch_input_metas.append(metainfo)
results_list = self.predict_by_feat(
points, preds_dict, batch_input_metas, rescale=rescale, **kwargs)
points, preds_dict, batch_input_metas, use_nms=use_nms, **kwargs)
return results_list
def loss(self, points: List[torch.Tensor], feats_dict: Dict[str,
torch.Tensor],
batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
def loss_and_predict(self,
points: List[torch.Tensor],
feats_dict: Dict[str, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
ret_target: bool = False,
proposal_cfg: dict = None,
**kwargs) -> Tuple:
"""
Args:
points (list[tensor]): Point clouds of multiple samples.
feats_dict (dict): Predictions from backbone or FPN.
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each sample and
corresponding annotations.
ret_target (bool): Whether to return the assigned targets.
Defaults to False.
proposal_cfg (dict): Config for the proposal process.
Defaults to None.
Returns:
tuple: Contains loss and predictions after post-process.
"""
preds_dict = self.forward(feats_dict)
feats_dict.update(preds_dict)
batch_gt_instance_3d = []
batch_gt_instances_ignore = []
batch_input_metas = []
batch_pts_semantic_mask = []
batch_pts_instance_mask = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
batch_gt_instance_3d.append(data_sample.gt_instances_3d)
batch_gt_instances_ignore.append(
data_sample.get('ignored_instances', None))
batch_pts_semantic_mask.append(
data_sample.gt_pts_seg.get('pts_semantic_mask', None))
batch_pts_instance_mask.append(
data_sample.gt_pts_seg.get('pts_instance_mask', None))
loss_inputs = (points, preds_dict, batch_gt_instance_3d)
losses = self.loss_by_feat(
*loss_inputs,
batch_pts_semantic_mask=batch_pts_semantic_mask,
batch_pts_instance_mask=batch_pts_instance_mask,
batch_input_metas=batch_input_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore,
ret_target=ret_target,
**kwargs)
results_list = self.predict_by_feat(
points,
preds_dict,
batch_input_metas,
use_nms=proposal_cfg.use_nms,
**kwargs)
return losses, results_list
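`loss_and_predict` lets the RPN stage compute its losses and emit proposals from a single forward pass, instead of the separate `loss`/`get_bboxes` calls the old code needed. The consuming side, condensed from `H3DNet.loss` later in this commit:

```python
# Condensed from H3DNet.loss below (not standalone; `self` is the detector).
proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
rpn_losses, rpn_proposals = self.rpn_head.loss_and_predict(
    points,              # list[Tensor], one point cloud per sample
    feats_dict,          # updated in place with the head's predictions
    batch_data_samples,  # carry gt_instances_3d and meta info
    ret_target=True,     # keep assigned targets for the RoI stage
    proposal_cfg=proposal_cfg)
feats_dict['targets'] = rpn_losses.pop('targets')
feats_dict['rpn_proposals'] = rpn_proposals
```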
def loss(self,
points: List[torch.Tensor],
feats_dict: Dict[str, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
ret_target: bool = False,
**kwargs) -> dict:
"""
Args:
points (list[tensor]): Point clouds of multiple samples.
......@@ -198,6 +263,8 @@ class VoteHead(BaseModule):
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each sample and
corresponding annotations.
ret_target (bool): Whether to return the assigned targets.
Defaults to False.
Returns:
dict: A dictionary of loss components.
......@@ -224,7 +291,9 @@ class VoteHead(BaseModule):
batch_pts_semantic_mask=batch_pts_semantic_mask,
batch_pts_instance_mask=batch_pts_instance_mask,
batch_input_metas=batch_input_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore)
batch_gt_instances_ignore=batch_gt_instances_ignore,
ret_target=ret_target,
**kwargs)
return losses
def forward(self, feat_dict: dict) -> dict:
......@@ -330,7 +399,7 @@ class VoteHead(BaseModule):
batch_pts_instance_mask (list[tensor]): Instance mask
of points cloud. Defaults to None.
batch_input_metas (list[dict]): Contains pcd and img's meta info.
ret_target (bool): Return targets or not.
ret_target (bool): Return targets or not. Defaults to False.
Returns:
dict: Losses of Votenet.
......@@ -671,9 +740,10 @@ class VoteHead(BaseModule):
while using vote head in rpn stage.
Returns:
list[:obj:`InstanceData`]: List of processed predictions. Each
InstanceData cantains 3d Bounding boxes and corresponding
scores and labels.
list[:obj:`InstanceData`] or Tensor: Return list of processed
predictions when `use_nms` is True. Each InstanceData contains
3D bounding boxes and corresponding scores and labels.
Return raw bboxes when `use_nms` is False.
"""
# decode boxes
stack_points = torch.stack(points)
......@@ -683,9 +753,9 @@ class VoteHead(BaseModule):
batch_size = bbox3d.shape[0]
results_list = list()
for b in range(batch_size):
temp_results = InstanceData()
if use_nms:
if use_nms:
for b in range(batch_size):
temp_results = InstanceData()
bbox_selected, score_selected, labels = \
self.multiclass_nms_single(obj_scores[b],
sem_scores[b],
......@@ -700,20 +770,15 @@ class VoteHead(BaseModule):
temp_results.scores_3d = score_selected
temp_results.labels_3d = labels
results_list.append(temp_results)
else:
bbox = batch_input_metas[b]['box_type_3d'](
bbox_selected,
box_dim=bbox_selected.shape[-1],
with_yaw=self.bbox_coder.with_rot)
temp_results.bboxes_3d = bbox
temp_results.obj_scores_3d = obj_scores[b]
temp_results.sem_scores_3d = obj_scores[b]
results_list.append(temp_results)
return results_list
return results_list
else:
# TODO unify it when refactor the Augtest
return bbox3d
def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
input_meta):
def multiclass_nms_single(self, obj_scores: Tensor, sem_scores: Tensor,
bbox: Tensor, points: Tensor,
input_meta: dict) -> Tuple:
"""Multi-class nms in single batch.
Args:
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
from torch import Tensor
from mmdet3d.core import merge_aug_bboxes_3d
from mmdet3d.registry import MODELS
from ...core import Det3DDataSample
from .two_stage import TwoStage3DDetector
......@@ -11,17 +14,33 @@ class H3DNet(TwoStage3DDetector):
r"""H3DNet model.
Please refer to the `paper <https://arxiv.org/abs/2006.05682>`_
Args:
backbone (dict): Config dict of detector's backbone.
neck (dict, optional): Config dict of neck. Defaults to None.
rpn_head (dict, optional): Config dict of rpn head. Defaults to None.
roi_head (dict, optional): Config dict of roi head. Defaults to None.
train_cfg (dict, optional): Config dict of training hyper-parameters.
Defaults to None.
test_cfg (dict, optional): Config dict of test hyper-parameters.
Defaults to None.
init_cfg (dict, optional): The config to control the
initialization. Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`BaseDataPreprocessor`. It usually includes
``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
"""
def __init__(self,
backbone,
neck=None,
rpn_head=None,
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
backbone: dict,
neck: Optional[dict] = None,
rpn_head: Optional[dict] = None,
roi_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
data_preprocessor: Optional[dict] = None,
**kwargs) -> None:
super(H3DNet, self).__init__(
backbone=backbone,
neck=neck,
......@@ -29,148 +48,110 @@ class H3DNet(TwoStage3DDetector):
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg)
def forward_train(self,
points,
img_metas,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
gt_bboxes_ignore=None):
"""Forward of training.
init_cfg=init_cfg,
data_preprocessor=data_preprocessor,
**kwargs)
def extract_feat(self, batch_inputs_dict: dict) -> dict:
"""Directly extract features from the backbone+neck.
Args:
points (list[torch.Tensor]): Points of each batch.
img_metas (list): Image metas.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
pts_semantic_mask (list[torch.Tensor]): point-wise semantic
label of each batch.
pts_instance_mask (list[torch.Tensor]): point-wise instance
label of each batch.
gt_bboxes_ignore (list[torch.Tensor]): Specify
which bounding.
batch_inputs_dict (dict): The model input dict which includes
the 'points' key.
- points (list[torch.Tensor]): Point cloud of each sample.
Returns:
dict: Losses.
dict: Dict of features.
"""
points_cat = torch.stack(points)
stack_points = torch.stack(batch_inputs_dict['points'])
x = self.backbone(stack_points)
if self.with_neck:
x = self.neck(x)
return x
def loss(self, batch_inputs_dict: Dict[str, Union[List, Tensor]],
batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
"""
Args:
batch_inputs_dict (dict): The model input dict which includes
the 'points' key.
- points (list[torch.Tensor]): Point cloud of each sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
feats_dict = self.extract_feat(batch_inputs_dict)
feats_dict = self.extract_feat(points_cat)
feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
feats_dict['fp_features'] = [feats_dict['hd_feature']]
feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
losses = dict()
if self.with_rpn:
rpn_outs = self.rpn_head(feats_dict, self.train_cfg.rpn.sample_mod)
feats_dict.update(rpn_outs)
rpn_loss_inputs = (points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask, img_metas)
rpn_losses = self.rpn_head.loss(
rpn_outs,
*rpn_loss_inputs,
gt_bboxes_ignore=gt_bboxes_ignore,
ret_target=True)
feats_dict['targets'] = rpn_losses.pop('targets')
losses.update(rpn_losses)
# Generate rpn proposals
proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn)
proposal_inputs = (points, rpn_outs, img_metas)
proposal_list = self.rpn_head.get_bboxes(
*proposal_inputs, use_nms=proposal_cfg.use_nms)
feats_dict['proposal_list'] = proposal_list
# note, the feats_dict would be added new key & value in rpn_head
rpn_losses, rpn_proposals = self.rpn_head.loss_and_predict(
batch_inputs_dict['points'],
feats_dict,
batch_data_samples,
ret_target=True,
proposal_cfg=proposal_cfg)
feats_dict['targets'] = rpn_losses.pop('targets')
losses.update(rpn_losses)
feats_dict['rpn_proposals'] = rpn_proposals
else:
raise NotImplementedError
roi_losses = self.roi_head.forward_train(feats_dict, img_metas, points,
gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask,
pts_instance_mask,
gt_bboxes_ignore)
roi_losses = self.roi_head.loss(batch_inputs_dict['points'],
feats_dict, batch_data_samples,
**kwargs)
losses.update(roi_losses)
return losses
def simple_test(self, points, img_metas, imgs=None, rescale=False):
"""Forward of testing.
def predict(
self, batch_input_dict: Dict,
batch_data_samples: List[Det3DDataSample]
) -> List[Det3DDataSample]:
"""Get model predictions.
Args:
points (list[torch.Tensor]): Points of each sample.
img_metas (list): Image metas.
rescale (bool): Whether to rescale results.
batch_input_dict (dict): The model input dict which includes
the 'points' key.
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each sample and
corresponding annotations.
Returns:
list[:obj:`Det3DDataSample`]: Predicted 3D data samples.
"""
points_cat = torch.stack(points)
feats_dict = self.extract_feat(points_cat)
feats_dict = self.extract_feat(batch_input_dict)
feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
feats_dict['fp_features'] = [feats_dict['hd_feature']]
feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
if self.with_rpn:
proposal_cfg = self.test_cfg.rpn
rpn_outs = self.rpn_head(feats_dict, proposal_cfg.sample_mod)
feats_dict.update(rpn_outs)
# Generate rpn proposals
proposal_list = self.rpn_head.get_bboxes(
points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms)
feats_dict['proposal_list'] = proposal_list
rpn_proposals = self.rpn_head.predict(
batch_input_dict['points'],
feats_dict,
batch_data_samples,
use_nms=proposal_cfg.use_nms)
feats_dict['rpn_proposals'] = rpn_proposals
else:
raise NotImplementedError
return self.roi_head.simple_test(
feats_dict, img_metas, points_cat, rescale=rescale)
def aug_test(self, points, img_metas, imgs=None, rescale=False):
"""Test with augmentation."""
points_cat = [torch.stack(pts) for pts in points]
feats_dict = self.extract_feats(points_cat, img_metas)
for feat_dict in feats_dict:
feat_dict['fp_xyz'] = [feat_dict['fp_xyz_net0'][-1]]
feat_dict['fp_features'] = [feat_dict['hd_feature']]
feat_dict['fp_indices'] = [feat_dict['fp_indices_net0'][-1]]
# only support aug_test for one sample
aug_bboxes = []
for feat_dict, pts_cat, img_meta in zip(feats_dict, points_cat,
img_metas):
if self.with_rpn:
proposal_cfg = self.test_cfg.rpn
rpn_outs = self.rpn_head(feat_dict, proposal_cfg.sample_mod)
feat_dict.update(rpn_outs)
# Generate rpn proposals
proposal_list = self.rpn_head.get_bboxes(
points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms)
feat_dict['proposal_list'] = proposal_list
else:
raise NotImplementedError
bbox_results = self.roi_head.simple_test(
feat_dict,
self.test_cfg.rcnn.sample_mod,
img_meta,
pts_cat,
rescale=rescale)
aug_bboxes.append(bbox_results)
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)
return [merged_bboxes]
def extract_feats(self, points, img_metas):
"""Extract features of multiple samples."""
return [
self.extract_feat(pts, img_meta)
for pts, img_meta in zip(points, img_metas)
]
results_list = self.roi_head.predict(
batch_input_dict['points'],
feats_dict,
batch_data_samples,
suffix='_optimized')
return self.convert_to_datasample(results_list)
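End to end, the refactored detector is now driven through `data_preprocessor` and a `mode`-switched `forward`, replacing `simple_test`/`aug_test`. A condensed inference sketch, taken from the unit test at the end of this commit (CUDA branch omitted):

```python
import torch
from mmengine import DefaultScope
from mmdet3d.registry import MODELS
from tests.utils.model_utils import _create_detector_inputs, _get_detector_cfg

DefaultScope.get_instance('h3dnet_demo', scope_name='mmdet3d')
model = MODELS.build(_get_detector_cfg('h3dnet/h3dnet_3x8_scannet-3d-18class.py'))
data = [_create_detector_inputs(num_gt_instance=5, points_feat_dim=4,
                                bboxes_3d_type='depth',
                                with_pts_semantic_mask=True,
                                with_pts_instance_mask=True)]
with torch.no_grad():
    batch_inputs, data_samples = model.data_preprocessor(data, True)
    results = model(batch_inputs, data_samples, mode='predict')  # List[Det3DDataSample]
```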
......@@ -56,12 +56,12 @@ class PointRCNN(TwoStage3DDetector):
x = self.neck(x)
return x
def forward_train(self, points, img_metas, gt_bboxes_3d, gt_labels_3d):
def forward_train(self, points, input_metas, gt_bboxes_3d, gt_labels_3d):
"""Forward of training.
Args:
points (list[torch.Tensor]): Points of each batch.
img_metas (list[dict]): Meta information of each sample.
input_metas (list[dict]): Meta information of each sample.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
......@@ -69,8 +69,8 @@ class PointRCNN(TwoStage3DDetector):
dict: Losses.
"""
losses = dict()
points_cat = torch.stack(points)
x = self.extract_feat(points_cat)
stack_points = torch.stack(points)
x = self.extract_feat(stack_points)
# features for rcnn
backbone_feats = x['fp_features'].clone()
......@@ -85,11 +85,11 @@ class PointRCNN(TwoStage3DDetector):
points=points,
gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=gt_labels_3d,
img_metas=img_metas)
input_metas=input_metas)
losses.update(rpn_loss)
bbox_list = self.rpn_head.get_bboxes(points_cat, bbox_preds, cls_preds,
img_metas)
bbox_list = self.rpn_head.get_bboxes(stack_points, bbox_preds,
cls_preds, input_metas)
proposal_list = [
dict(
boxes_3d=bboxes,
......@@ -100,7 +100,7 @@ class PointRCNN(TwoStage3DDetector):
]
rcnn_feats.update({'points_cls_preds': cls_preds})
roi_losses = self.roi_head.forward_train(rcnn_feats, img_metas,
roi_losses = self.roi_head.forward_train(rcnn_feats, input_metas,
proposal_list, gt_bboxes_3d,
gt_labels_3d)
losses.update(roi_losses)
......@@ -121,9 +121,9 @@ class PointRCNN(TwoStage3DDetector):
Returns:
list: Predicted 3d boxes.
"""
points_cat = torch.stack(points)
stack_points = torch.stack(points)
x = self.extract_feat(points_cat)
x = self.extract_feat(stack_points)
# features for rcnn
backbone_feats = x['fp_features'].clone()
backbone_xyz = x['fp_xyz'].clone()
......@@ -132,7 +132,7 @@ class PointRCNN(TwoStage3DDetector):
rcnn_feats.update({'points_cls_preds': cls_preds})
bbox_list = self.rpn_head.get_bboxes(
points_cat, bbox_preds, cls_preds, img_metas, rescale=rescale)
stack_points, bbox_preds, cls_preds, img_metas, rescale=rescale)
proposal_list = [
dict(
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.core.bbox import bbox3d2result
from typing import Dict, List
from mmengine import InstanceData
from torch import Tensor
from mmdet3d.registry import MODELS
from ...core import Det3DDataSample
from .base_3droi_head import Base3DRoIHead
......@@ -16,17 +21,15 @@ class H3DRoIHead(Base3DRoIHead):
"""
def __init__(self,
primitive_list,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
primitive_list: List[dict],
bbox_head: dict = None,
train_cfg: dict = None,
test_cfg: dict = None,
init_cfg: dict = None):
super(H3DRoIHead, self).__init__(
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg)
# Primitive module
assert len(primitive_list) == 3
......@@ -39,8 +42,14 @@ class H3DRoIHead(Base3DRoIHead):
one."""
pass
def init_bbox_head(self, bbox_head):
"""Initialize box head."""
def init_bbox_head(self, dummy_args, bbox_head):
"""Initialize box head.
Args:
dummy_args (optional): Just to be compatible with
the interface in the base class.
bbox_head (dict): Config for bbox head.
"""
bbox_head['train_cfg'] = self.train_cfg
bbox_head['test_cfg'] = self.test_cfg
self.bbox_head = MODELS.build(bbox_head)
......@@ -49,111 +58,73 @@ class H3DRoIHead(Base3DRoIHead):
"""Initialize assigner and sampler."""
pass
def forward_train(self,
feats_dict,
img_metas,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask,
pts_instance_mask,
gt_bboxes_ignore=None):
def loss(self, points: List[Tensor], feats_dict: dict,
batch_data_samples: List[Det3DDataSample], **kwargs):
"""Training forward function of PartAggregationROIHead.
Args:
feats_dict (dict): Contains features from the first stage.
img_metas (list[dict]): Contain pcd and img's meta info.
points (list[torch.Tensor]): Input points.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bboxes of each sample.
gt_labels_3d (list[torch.Tensor]): Labels of each sample.
pts_semantic_mask (list[torch.Tensor]): Point-wise
semantic mask.
pts_instance_mask (list[torch.Tensor]): Point-wise
instance mask.
gt_bboxes_ignore (list[torch.Tensor]): Specify
which bounding boxes to ignore.
points (list[torch.Tensor]): Point cloud of each sample.
feats_dict (dict): Dict of features.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`.
Returns:
dict: losses from each head.
"""
losses = dict()
sample_mod = self.train_cfg.sample_mod
assert sample_mod in ['vote', 'seed', 'random']
result_z = self.primitive_z(feats_dict, sample_mod)
feats_dict.update(result_z)
result_xy = self.primitive_xy(feats_dict, sample_mod)
feats_dict.update(result_xy)
result_line = self.primitive_line(feats_dict, sample_mod)
feats_dict.update(result_line)
primitive_loss_inputs = (feats_dict, points, gt_bboxes_3d,
gt_labels_3d, pts_semantic_mask,
pts_instance_mask, img_metas,
gt_bboxes_ignore)
primitive_loss_inputs = (points, feats_dict, batch_data_samples)
# note the feats_dict would be added new key and value in each head.
loss_z = self.primitive_z.loss(*primitive_loss_inputs)
losses.update(loss_z)
loss_xy = self.primitive_xy.loss(*primitive_loss_inputs)
losses.update(loss_xy)
loss_line = self.primitive_line.loss(*primitive_loss_inputs)
losses.update(loss_z)
losses.update(loss_xy)
losses.update(loss_line)
targets = feats_dict.pop('targets')
bbox_results = self.bbox_head(feats_dict, sample_mod)
feats_dict.update(bbox_results)
bbox_loss = self.bbox_head.loss(feats_dict, points, gt_bboxes_3d,
gt_labels_3d, pts_semantic_mask,
pts_instance_mask, img_metas, targets,
gt_bboxes_ignore)
bbox_loss = self.bbox_head.loss(
points,
feats_dict,
rpn_targets=targets,
batch_data_samples=batch_data_samples)
losses.update(bbox_loss)
return losses
def simple_test(self, feats_dict, img_metas, points, rescale=False):
"""Simple testing forward function of PartAggregationROIHead.
Note:
This function assumes that the batch size is 1
def predict(self,
points: List[Tensor],
feats_dict: Dict[str, Tensor],
batch_data_samples: List[Det3DDataSample],
suffix='_optimized',
**kwargs) -> List[InstanceData]:
"""
Args:
feats_dict (dict): Contains features from the first stage.
img_metas (list[dict]): Contain pcd and img's meta info.
points (torch.Tensor): Input points.
rescale (bool): Whether to rescale results.
points (list[tensor]): Point clouds of multiple samples.
feats_dict (dict): Features from FPN or backbone.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes meta information of data.
Returns:
dict: Bbox results of one frame.
list[:obj:`InstanceData`]: List of processed predictions. Each
InstanceData contains 3D bounding boxes and corresponding
scores and labels.
"""
sample_mod = self.test_cfg.sample_mod
assert sample_mod in ['vote', 'seed', 'random']
result_z = self.primitive_z(feats_dict, sample_mod)
result_z = self.primitive_z(feats_dict)
feats_dict.update(result_z)
result_xy = self.primitive_xy(feats_dict, sample_mod)
result_xy = self.primitive_xy(feats_dict)
feats_dict.update(result_xy)
result_line = self.primitive_line(feats_dict, sample_mod)
result_line = self.primitive_line(feats_dict)
feats_dict.update(result_line)
bbox_preds = self.bbox_head(feats_dict, sample_mod)
bbox_preds = self.bbox_head(feats_dict)
feats_dict.update(bbox_preds)
bbox_list = self.bbox_head.get_bboxes(
points,
feats_dict,
img_metas,
rescale=rescale,
suffix='_optimized')
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results
results_list = self.bbox_head.predict(
points, feats_dict, batch_data_samples, suffix=suffix)
return results_list
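`predict` threads one `feats_dict` through the three primitive heads and the bbox head, each call enriching the dict that the next one reads; `suffix='_optimized'` then selects the refined boxes from the bbox head's outputs. A toy illustration of the cascade (assumed structure, not the mmdet3d API):

```python
# Each "head" reads the shared dict and contributes new keys, mirroring how
# predict() chains primitive_z, primitive_xy and primitive_line above.
def make_primitive(mode):
    def head(feats):
        return {'center_' + mode: feats['seed_feature'] * 2}
    return head

feats = {'seed_feature': 1.0}
for head in (make_primitive('z'), make_primitive('xy'), make_primitive('line')):
    feats.update(head(feats))
print(sorted(feats))  # ['center_line', 'center_xy', 'center_z', 'seed_feature']
```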
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional
import torch
from mmcv.cnn import ConvModule
from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule
from mmengine import InstanceData
from torch import nn as nn
from torch.nn import functional as F
from mmdet3d.models.builder import build_loss
from mmdet3d.core import Det3DDataSample
from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import build_sa_module
from mmdet3d.registry import MODELS
......@@ -40,24 +43,25 @@ class PrimitiveHead(BaseModule):
"""
def __init__(self,
num_dims,
num_classes,
primitive_mode,
train_cfg=None,
test_cfg=None,
vote_module_cfg=None,
vote_aggregation_cfg=None,
feat_channels=(128, 128),
upper_thresh=100.0,
surface_thresh=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=None,
center_loss=None,
semantic_reg_loss=None,
semantic_cls_loss=None,
init_cfg=None):
num_dims: int,
num_classes: int,
primitive_mode: str,
train_cfg: dict = None,
test_cfg: dict = None,
vote_module_cfg: dict = None,
vote_aggregation_cfg: dict = None,
feat_channels: tuple = (128, 128),
upper_thresh: float = 100.0,
surface_thresh: float = 0.5,
conv_cfg: dict = dict(type='Conv1d'),
norm_cfg: dict = dict(type='BN1d'),
objectness_loss: dict = None,
center_loss: dict = None,
semantic_reg_loss: dict = None,
semantic_cls_loss: dict = None,
init_cfg: dict = None):
super(PrimitiveHead, self).__init__(init_cfg=init_cfg)
# bounding boxes centers, face centers and edge centers
assert primitive_mode in ['z', 'xy', 'line']
# The dimension of primitive semantic information.
self.num_dims = num_dims
......@@ -70,10 +74,10 @@ class PrimitiveHead(BaseModule):
self.upper_thresh = upper_thresh
self.surface_thresh = surface_thresh
self.objectness_loss = build_loss(objectness_loss)
self.center_loss = build_loss(center_loss)
self.semantic_reg_loss = build_loss(semantic_reg_loss)
self.semantic_cls_loss = build_loss(semantic_cls_loss)
self.loss_objectness = MODELS.build(objectness_loss)
self.loss_center = MODELS.build(center_loss)
self.loss_semantic_reg = MODELS.build(semantic_reg_loss)
self.loss_semantic_cls = MODELS.build(semantic_cls_loss)
assert vote_aggregation_cfg['mlp_channels'][0] == vote_module_cfg[
'in_channels']
......@@ -114,18 +118,26 @@ class PrimitiveHead(BaseModule):
self.conv_pred.add_module('conv_out',
nn.Conv1d(prev_channel, conv_out_channel, 1))
def forward(self, feats_dict, sample_mod):
@property
def sample_mode(self):
if self.training:
sample_mode = self.train_cfg.sample_mode
else:
sample_mode = self.test_cfg.sample_mode
assert sample_mode in ['vote', 'seed', 'random']
return sample_mode
def forward(self, feats_dict):
"""Forward pass.
Args:
feats_dict (dict): Feature dict from backbone.
sample_mod (str): Sample mode for vote aggregation layer.
valid modes are "vote", "seed" and "random".
Returns:
dict: Predictions of primitive head.
"""
assert sample_mod in ['vote', 'seed', 'random']
sample_mode = self.sample_mode
seed_points = feats_dict['fp_xyz_net0'][-1]
seed_features = feats_dict['hd_feature']
......@@ -143,14 +155,14 @@ class PrimitiveHead(BaseModule):
results['vote_features_' + self.primitive_mode] = vote_features
# 2. aggregate vote_points
if sample_mod == 'vote':
if sample_mode == 'vote':
# use fps in vote_aggregation
sample_indices = None
elif sample_mod == 'seed':
elif sample_mode == 'seed':
# FPS on seed and choose the votes corresponding to the seeds
sample_indices = furthest_point_sample(seed_points,
self.num_proposal)
elif sample_mod == 'random':
elif sample_mode == 'random':
# Random sampling from the votes
batch_size, num_seed = seed_points.shape[:2]
sample_indices = torch.randint(
......@@ -185,63 +197,103 @@ class PrimitiveHead(BaseModule):
results['pred_' + self.primitive_mode + '_center'] = center
return results
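The new `sample_mode` property replaces the `sample_mod` argument that used to be threaded through every forward call: the head now reads the mode from `train_cfg` or `test_cfg` depending on `self.training`. A standalone toy version of the pattern (names are illustrative, not from this commit):

```python
import torch.nn as nn

class ToyHead(nn.Module):
    """Toy stand-in for PrimitiveHead's sample_mode property."""

    def __init__(self, train_cfg, test_cfg):
        super().__init__()
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    @property
    def sample_mode(self):
        # pick the config matching the module's current train/eval mode
        cfg = self.train_cfg if self.training else self.test_cfg
        assert cfg['sample_mode'] in ('vote', 'seed', 'random')
        return cfg['sample_mode']

head = ToyHead(dict(sample_mode='vote'), dict(sample_mode='seed'))
head.train()
assert head.sample_mode == 'vote'
head.eval()
assert head.sample_mode == 'seed'
```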
def loss(self,
bbox_preds,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
img_metas=None,
gt_bboxes_ignore=None):
def loss(self, points: List[torch.Tensor], feats_dict: Dict[str,
torch.Tensor],
batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
"""
Args:
points (list[tensor]): Point clouds of multiple samples.
feats_dict (dict): Predictions from backbone or FPN.
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each sample and
corresponding annotations.
Returns:
dict: A dictionary of loss components.
"""
preds = self(feats_dict)
feats_dict.update(preds)
batch_gt_instance_3d = []
batch_gt_instances_ignore = []
batch_input_metas = []
batch_pts_semantic_mask = []
batch_pts_instance_mask = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
batch_gt_instance_3d.append(data_sample.gt_instances_3d)
batch_gt_instances_ignore.append(
data_sample.get('ignored_instances', None))
batch_pts_semantic_mask.append(
data_sample.gt_pts_seg.get('pts_semantic_mask', None))
batch_pts_instance_mask.append(
data_sample.gt_pts_seg.get('pts_instance_mask', None))
loss_inputs = (points, feats_dict, batch_gt_instance_3d)
losses = self.loss_by_feat(
*loss_inputs,
batch_pts_semantic_mask=batch_pts_semantic_mask,
batch_pts_instance_mask=batch_pts_instance_mask,
batch_gt_instances_ignore=batch_gt_instances_ignore,
)
return losses
def loss_by_feat(
self,
points: List[torch.Tensor],
feats_dict: dict,
batch_gt_instances_3d: List[InstanceData],
batch_pts_semantic_mask: Optional[List[torch.Tensor]] = None,
batch_pts_instance_mask: Optional[List[torch.Tensor]] = None,
**kwargs):
"""Compute loss.
Args:
bbox_preds (dict): Predictions from forward of primitive head.
points (list[torch.Tensor]): Input points.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bboxes of each sample.
gt_labels_3d (list[torch.Tensor]): Labels of each sample.
pts_semantic_mask (list[torch.Tensor]): Point-wise
semantic mask.
pts_instance_mask (list[torch.Tensor]): Point-wise
instance mask.
img_metas (list[dict]): Contain pcd and img's meta info.
gt_bboxes_ignore (list[torch.Tensor]): Specify
which bounding.
feats_dict (dict): Predictions of previous modules.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes`` and ``labels``
attributes.
batch_pts_semantic_mask (list[tensor]): Semantic mask
of points cloud. Defaults to None.
batch_pts_instance_mask (list[tensor]): Instance mask
of points cloud. Defaults to None.
batch_input_metas (list[dict]): Contains pcd and img's meta info.
ret_target (bool): Return targets or not. Defaults to False.
Returns:
dict: Losses of Primitive Head.
"""
targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask,
bbox_preds)
targets = self.get_targets(points, feats_dict, batch_gt_instances_3d,
batch_pts_semantic_mask,
batch_pts_instance_mask)
(point_mask, point_offset, gt_primitive_center, gt_primitive_semantic,
gt_sem_cls_label, gt_primitive_mask) = targets
losses = {}
# Compute the loss of primitive existence flag
pred_flag = bbox_preds['pred_flag_' + self.primitive_mode]
flag_loss = self.objectness_loss(pred_flag, gt_primitive_mask.long())
pred_flag = feats_dict['pred_flag_' + self.primitive_mode]
flag_loss = self.loss_objectness(pred_flag, gt_primitive_mask.long())
losses['flag_loss_' + self.primitive_mode] = flag_loss
# calculate vote loss
vote_loss = self.vote_module.get_loss(
bbox_preds['seed_points'],
bbox_preds['vote_' + self.primitive_mode],
bbox_preds['seed_indices'], point_mask, point_offset)
feats_dict['seed_points'],
feats_dict['vote_' + self.primitive_mode],
feats_dict['seed_indices'], point_mask, point_offset)
losses['vote_loss_' + self.primitive_mode] = vote_loss
num_proposal = bbox_preds['aggregated_points_' +
num_proposal = feats_dict['aggregated_points_' +
self.primitive_mode].shape[1]
primitive_center = bbox_preds['center_' + self.primitive_mode]
primitive_center = feats_dict['center_' + self.primitive_mode]
if self.primitive_mode != 'line':
primitive_semantic = bbox_preds['size_residuals_' +
primitive_semantic = feats_dict['size_residuals_' +
self.primitive_mode].contiguous()
else:
primitive_semantic = None
semancitc_scores = bbox_preds['sem_cls_scores_' +
semancitc_scores = feats_dict['sem_cls_scores_' +
self.primitive_mode].transpose(2, 1)
gt_primitive_mask = gt_primitive_mask / \
......@@ -256,44 +308,61 @@ class PrimitiveHead(BaseModule):
return losses
def get_targets(self,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
bbox_preds=None):
def get_targets(
self,
points,
bbox_preds: Optional[dict] = None,
batch_gt_instances_3d: List[InstanceData] = None,
batch_pts_semantic_mask: List[torch.Tensor] = None,
batch_pts_instance_mask: List[torch.Tensor] = None,
):
"""Generate targets of primitive head.
Args:
points (list[torch.Tensor]): Points of each batch.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): Labels of each batch.
pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
label of each batch.
pts_instance_mask (list[torch.Tensor]): Point-wise instance
label of each batch.
bbox_preds (dict): Predictions from forward of primitive head.
bbox_preds (dict, optional): Predictions from the forward pass
of the primitive head. Defaults to None.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
batch_pts_semantic_mask (list[tensor]): Semantic gt mask for
multiple samples.
batch_pts_instance_mask (list[tensor]): Instance gt mask for
multiple samples.
Returns:
tuple[torch.Tensor]: Targets of primitive head.
"""
for index in range(len(gt_labels_3d)):
if len(gt_labels_3d[index]) == 0:
fake_box = gt_bboxes_3d[index].tensor.new_zeros(
1, gt_bboxes_3d[index].tensor.shape[-1])
gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
if pts_semantic_mask is None:
pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
pts_instance_mask = [None for i in range(len(gt_labels_3d))]
batch_gt_labels_3d = [
gt_instances_3d.labels_3d
for gt_instances_3d in batch_gt_instances_3d
]
batch_gt_bboxes_3d = [
gt_instances_3d.bboxes_3d
for gt_instances_3d in batch_gt_instances_3d
]
for index in range(len(batch_gt_labels_3d)):
if len(batch_gt_labels_3d[index]) == 0:
fake_box = batch_gt_bboxes_3d[index].tensor.new_zeros(
1, batch_gt_bboxes_3d[index].tensor.shape[-1])
batch_gt_bboxes_3d[index] = batch_gt_bboxes_3d[index].new_box(
fake_box)
batch_gt_labels_3d[index] = batch_gt_labels_3d[
index].new_zeros(1)
if batch_pts_semantic_mask is None:
batch_pts_semantic_mask = [
None for _ in range(len(batch_gt_labels_3d))
]
batch_pts_instance_mask = [
None for _ in range(len(batch_gt_labels_3d))
]
(point_mask, point_sem,
point_offset) = multi_apply(self.get_targets_single, points,
gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask)
batch_gt_bboxes_3d, batch_gt_labels_3d,
batch_pts_semantic_mask,
batch_pts_instance_mask)
point_mask = torch.stack(point_mask)
point_sem = torch.stack(point_sem)
......@@ -759,7 +828,7 @@ class PrimitiveHead(BaseModule):
vote_xyz_reshape = primitive_center.view(batch_size * num_proposal, -1,
3)
center_loss = self.center_loss(
center_loss = self.loss_center(
vote_xyz_reshape,
gt_primitive_center,
dst_weight=gt_primitive_mask.view(batch_size * num_proposal, 1))[1]
......@@ -767,7 +836,7 @@ class PrimitiveHead(BaseModule):
if self.primitive_mode != 'line':
size_xyz_reshape = primitive_semantic.view(
batch_size * num_proposal, -1, self.num_dims).contiguous()
size_loss = self.semantic_reg_loss(
size_loss = self.loss_semantic_reg(
size_xyz_reshape,
gt_primitive_semantic,
dst_weight=gt_primitive_mask.view(batch_size * num_proposal,
......@@ -776,7 +845,7 @@ class PrimitiveHead(BaseModule):
size_loss = center_loss.new_tensor(0.0)
# Semantic cls loss
sem_cls_loss = self.semantic_cls_loss(
sem_cls_loss = self.loss_semantic_cls(
semantic_scores, gt_sem_cls_label, weight=gt_primitive_mask)
return center_loss, size_loss, sem_cls_loss
......
import unittest
import torch
from mmengine import DefaultScope
from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
_get_detector_cfg, _setup_seed)
class TestH3D(unittest.TestCase):
def test_h3dnet(self):
import mmdet3d.models
assert hasattr(mmdet3d.models, 'H3DNet')
DefaultScope.get_instance('test_H3DNet', scope_name='mmdet3d')
_setup_seed(0)
h3dnet_cfg = _get_detector_cfg(
'h3dnet/h3dnet_3x8_scannet-3d-18class.py')
model = MODELS.build(h3dnet_cfg)
num_gt_instance = 5
data = [
_create_detector_inputs(
num_gt_instance=num_gt_instance,
points_feat_dim=4,
bboxes_3d_type='depth',
with_pts_semantic_mask=True,
with_pts_instance_mask=True)
]
if torch.cuda.is_available():
model = model.cuda()
# test predict
with torch.no_grad():
batch_inputs, data_samples = model.data_preprocessor(
data, True)
results = model.forward(
batch_inputs, data_samples, mode='predict')
self.assertEqual(len(results), len(data))
self.assertIn('bboxes_3d', results[0].pred_instances_3d)
self.assertIn('scores_3d', results[0].pred_instances_3d)
self.assertIn('labels_3d', results[0].pred_instances_3d)
# compute losses under no_grad to save memory
with torch.no_grad():
losses = model.forward(batch_inputs, data_samples, mode='loss')
self.assertGreater(losses['vote_loss'], 0)
self.assertGreater(losses['objectness_loss'], 0)
self.assertGreater(losses['center_loss'], 0)
......@@ -7,7 +7,8 @@ import numpy as np
import torch
from mmengine import InstanceData
from mmdet3d.core import Det3DDataSample, LiDARInstance3DBoxes, PointData
from mmdet3d.core import (CameraInstance3DBoxes, DepthInstance3DBoxes,
Det3DDataSample, LiDARInstance3DBoxes, PointData)
def _setup_seed(seed):
......@@ -71,19 +72,24 @@ def _get_detector_cfg(fname):
return model
def _create_detector_inputs(
seed=0,
with_points=True,
with_img=False,
num_gt_instance=20,
num_points=10,
points_feat_dim=4,
num_classes=3,
gt_bboxes_dim=7,
with_pts_semantic_mask=False,
with_pts_instance_mask=False,
):
def _create_detector_inputs(seed=0,
with_points=True,
with_img=False,
num_gt_instance=20,
num_points=10,
points_feat_dim=4,
num_classes=3,
gt_bboxes_dim=7,
with_pts_semantic_mask=False,
with_pts_instance_mask=False,
bboxes_3d_type='lidar'):
_setup_seed(seed)
assert bboxes_3d_type in ('lidar', 'depth', 'cam')
bbox_3d_class = {
'lidar': LiDARInstance3DBoxes,
'depth': DepthInstance3DBoxes,
'cam': CameraInstance3DBoxes
}
if with_points:
points = torch.rand([num_points, points_feat_dim])
else:
......@@ -93,12 +99,13 @@ def _create_detector_inputs(
else:
img = None
inputs_dict = dict(img=img, points=points)
gt_instance_3d = InstanceData()
gt_instance_3d.bboxes_3d = LiDARInstance3DBoxes(
gt_instance_3d.bboxes_3d = bbox_3d_class[bboxes_3d_type](
torch.rand([num_gt_instance, gt_bboxes_dim]), box_dim=gt_bboxes_dim)
gt_instance_3d.labels_3d = torch.randint(0, num_classes, [num_gt_instance])
data_sample = Det3DDataSample(
metainfo=dict(box_type_3d=LiDARInstance3DBoxes))
metainfo=dict(box_type_3d=bbox_3d_class[bboxes_3d_type]))
data_sample.gt_instances_3d = gt_instance_3d
data_sample.gt_pts_seg = PointData()
if with_pts_instance_mask:
......