Commit 0e17beab authored by jshilong's avatar jshilong Committed by ChaimZhu
Browse files

[REfactor]Refactor H3D

parent 9ebb75da
...@@ -30,7 +30,7 @@ primitive_z_cfg = dict( ...@@ -30,7 +30,7 @@ primitive_z_cfg = dict(
conv_cfg=dict(type='Conv1d'), conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'), norm_cfg=dict(type='BN1d'),
objectness_loss=dict( objectness_loss=dict(
type='CrossEntropyLoss', type='mmdet.CrossEntropyLoss',
class_weight=[0.4, 0.6], class_weight=[0.4, 0.6],
reduction='mean', reduction='mean',
loss_weight=30.0), loss_weight=30.0),
...@@ -47,14 +47,16 @@ primitive_z_cfg = dict( ...@@ -47,14 +47,16 @@ primitive_z_cfg = dict(
loss_src_weight=0.5, loss_src_weight=0.5,
loss_dst_weight=0.5), loss_dst_weight=0.5),
semantic_cls_loss=dict( semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
train_cfg=dict( train_cfg=dict(
sample_mode='vote',
dist_thresh=0.2, dist_thresh=0.2,
var_thresh=1e-2, var_thresh=1e-2,
lower_thresh=1e-6, lower_thresh=1e-6,
num_point=100, num_point=100,
num_point_line=10, num_point_line=10,
line_thresh=0.2)) line_thresh=0.2),
test_cfg=dict(sample_mode='seed'))
primitive_xy_cfg = dict( primitive_xy_cfg = dict(
type='PrimitiveHead', type='PrimitiveHead',
...@@ -88,7 +90,7 @@ primitive_xy_cfg = dict( ...@@ -88,7 +90,7 @@ primitive_xy_cfg = dict(
conv_cfg=dict(type='Conv1d'), conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'), norm_cfg=dict(type='BN1d'),
objectness_loss=dict( objectness_loss=dict(
type='CrossEntropyLoss', type='mmdet.CrossEntropyLoss',
class_weight=[0.4, 0.6], class_weight=[0.4, 0.6],
reduction='mean', reduction='mean',
loss_weight=30.0), loss_weight=30.0),
...@@ -105,14 +107,16 @@ primitive_xy_cfg = dict( ...@@ -105,14 +107,16 @@ primitive_xy_cfg = dict(
loss_src_weight=0.5, loss_src_weight=0.5,
loss_dst_weight=0.5), loss_dst_weight=0.5),
semantic_cls_loss=dict( semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
train_cfg=dict( train_cfg=dict(
sample_mode='vote',
dist_thresh=0.2, dist_thresh=0.2,
var_thresh=1e-2, var_thresh=1e-2,
lower_thresh=1e-6, lower_thresh=1e-6,
num_point=100, num_point=100,
num_point_line=10, num_point_line=10,
line_thresh=0.2)) line_thresh=0.2),
test_cfg=dict(sample_mode='seed'))
primitive_line_cfg = dict( primitive_line_cfg = dict(
type='PrimitiveHead', type='PrimitiveHead',
...@@ -146,7 +150,7 @@ primitive_line_cfg = dict( ...@@ -146,7 +150,7 @@ primitive_line_cfg = dict(
conv_cfg=dict(type='Conv1d'), conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'), norm_cfg=dict(type='BN1d'),
objectness_loss=dict( objectness_loss=dict(
type='CrossEntropyLoss', type='mmdet.CrossEntropyLoss',
class_weight=[0.4, 0.6], class_weight=[0.4, 0.6],
reduction='mean', reduction='mean',
loss_weight=30.0), loss_weight=30.0),
...@@ -163,17 +167,20 @@ primitive_line_cfg = dict( ...@@ -163,17 +167,20 @@ primitive_line_cfg = dict(
loss_src_weight=1.0, loss_src_weight=1.0,
loss_dst_weight=1.0), loss_dst_weight=1.0),
semantic_cls_loss=dict( semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=2.0), type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=2.0),
train_cfg=dict( train_cfg=dict(
sample_mode='vote',
dist_thresh=0.2, dist_thresh=0.2,
var_thresh=1e-2, var_thresh=1e-2,
lower_thresh=1e-6, lower_thresh=1e-6,
num_point=100, num_point=100,
num_point_line=10, num_point_line=10,
line_thresh=0.2)) line_thresh=0.2),
test_cfg=dict(sample_mode='seed'))
model = dict( model = dict(
type='H3DNet', type='H3DNet',
data_preprocessor=dict(type='Det3DDataPreprocessor'),
backbone=dict( backbone=dict(
type='MultiBackbone', type='MultiBackbone',
num_streams=4, num_streams=4,
...@@ -221,10 +228,8 @@ model = dict( ...@@ -221,10 +228,8 @@ model = dict(
normalize_xyz=True), normalize_xyz=True),
pred_layer_cfg=dict( pred_layer_cfg=dict(
in_channels=128, shared_conv_channels=(128, 128), bias=True), in_channels=128, shared_conv_channels=(128, 128), bias=True),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict( objectness_loss=dict(
type='CrossEntropyLoss', type='mmdet.CrossEntropyLoss',
class_weight=[0.2, 0.8], class_weight=[0.2, 0.8],
reduction='sum', reduction='sum',
loss_weight=5.0), loss_weight=5.0),
...@@ -235,15 +240,15 @@ model = dict( ...@@ -235,15 +240,15 @@ model = dict(
loss_src_weight=10.0, loss_src_weight=10.0,
loss_dst_weight=10.0), loss_dst_weight=10.0),
dir_class_loss=dict( dir_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
dir_res_loss=dict( dir_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0), type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
size_class_loss=dict( size_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
size_res_loss=dict( size_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0), type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
semantic_loss=dict( semantic_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)), type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0)),
roi_head=dict( roi_head=dict(
type='H3DRoIHead', type='H3DRoIHead',
primitive_list=[primitive_z_cfg, primitive_xy_cfg, primitive_line_cfg], primitive_list=[primitive_z_cfg, primitive_xy_cfg, primitive_line_cfg],
...@@ -267,7 +272,6 @@ model = dict( ...@@ -267,7 +272,6 @@ model = dict(
mlp_channels=[128 + 12, 128, 64, 32], mlp_channels=[128 + 12, 128, 64, 32],
use_xyz=True, use_xyz=True,
normalize_xyz=True), normalize_xyz=True),
feat_channels=(128, 128),
primitive_refine_channels=[128, 128, 128], primitive_refine_channels=[128, 128, 128],
upper_thresh=100.0, upper_thresh=100.0,
surface_thresh=0.5, surface_thresh=0.5,
...@@ -275,7 +279,7 @@ model = dict( ...@@ -275,7 +279,7 @@ model = dict(
conv_cfg=dict(type='Conv1d'), conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'), norm_cfg=dict(type='BN1d'),
objectness_loss=dict( objectness_loss=dict(
type='CrossEntropyLoss', type='mmdet.CrossEntropyLoss',
class_weight=[0.2, 0.8], class_weight=[0.2, 0.8],
reduction='sum', reduction='sum',
loss_weight=5.0), loss_weight=5.0),
...@@ -286,41 +290,47 @@ model = dict( ...@@ -286,41 +290,47 @@ model = dict(
loss_src_weight=10.0, loss_src_weight=10.0,
loss_dst_weight=10.0), loss_dst_weight=10.0),
dir_class_loss=dict( dir_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1), type='mmdet.CrossEntropyLoss',
reduction='sum',
loss_weight=0.1),
dir_res_loss=dict( dir_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0), type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
size_class_loss=dict( size_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1), type='mmdet.CrossEntropyLoss',
reduction='sum',
loss_weight=0.1),
size_res_loss=dict( size_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0), type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
semantic_loss=dict( semantic_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1), type='mmdet.CrossEntropyLoss',
reduction='sum',
loss_weight=0.1),
cues_objectness_loss=dict( cues_objectness_loss=dict(
type='CrossEntropyLoss', type='mmdet.CrossEntropyLoss',
class_weight=[0.3, 0.7], class_weight=[0.3, 0.7],
reduction='mean', reduction='mean',
loss_weight=5.0), loss_weight=5.0),
cues_semantic_loss=dict( cues_semantic_loss=dict(
type='CrossEntropyLoss', type='mmdet.CrossEntropyLoss',
class_weight=[0.3, 0.7], class_weight=[0.3, 0.7],
reduction='mean', reduction='mean',
loss_weight=5.0), loss_weight=5.0),
proposal_objectness_loss=dict( proposal_objectness_loss=dict(
type='CrossEntropyLoss', type='mmdet.CrossEntropyLoss',
class_weight=[0.2, 0.8], class_weight=[0.2, 0.8],
reduction='none', reduction='none',
loss_weight=5.0), loss_weight=5.0),
primitive_center_loss=dict( primitive_center_loss=dict(
type='MSELoss', reduction='none', loss_weight=1.0))), type='mmdet.MSELoss', reduction='none', loss_weight=1.0))),
# model training and testing settings # model training and testing settings
train_cfg=dict( train_cfg=dict(
rpn=dict( rpn=dict(
pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'), pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mode='vote'),
rpn_proposal=dict(use_nms=False), rpn_proposal=dict(use_nms=False),
rcnn=dict( rcnn=dict(
pos_distance_thr=0.3, pos_distance_thr=0.3,
neg_distance_thr=0.6, neg_distance_thr=0.6,
sample_mod='vote', sample_mode='vote',
far_threshold=0.6, far_threshold=0.6,
near_threshold=0.3, near_threshold=0.3,
mask_surface_threshold=0.3, mask_surface_threshold=0.3,
...@@ -329,13 +339,13 @@ model = dict( ...@@ -329,13 +339,13 @@ model = dict(
label_line_threshold=0.3)), label_line_threshold=0.3)),
test_cfg=dict( test_cfg=dict(
rpn=dict( rpn=dict(
sample_mod='seed', sample_mode='seed',
nms_thr=0.25, nms_thr=0.25,
score_thr=0.05, score_thr=0.05,
per_class_proposal=True, per_class_proposal=True,
use_nms=False), use_nms=False),
rcnn=dict( rcnn=dict(
sample_mod='seed', sample_mode='seed',
nms_thr=0.25, nms_thr=0.25,
score_thr=0.05, score_thr=0.05,
per_class_proposal=True))) per_class_proposal=True)))
# Base configs this file inherits from: dataset, model, schedule and runtime.
_base_ = [
    '../_base_/datasets/scannet-3d-18class.py',
    '../_base_/models/h3dnet.py',
    '../_base_/schedules/schedule_3x.py',
    '../_base_/default_runtime.py',
]
# model settings
# Only the dataset-dependent parts of the two heads are overridden here:
# the class count and the box coder. Both heads carry the same coder
# settings, so the two `bbox_coder` dicts below must be kept in sync.
model = dict(
    rpn_head=dict(
        num_classes=18,
        bbox_coder=dict(
            type='PartialBinBasedBBoxCoder',
            num_sizes=18,
            num_dir_bins=24,
            with_rot=False,
            # 18 rows of 3 values, matching `num_sizes`.
            mean_sizes=[
                [0.76966727, 0.8116021, 0.92573744],
                [1.876858, 1.8425595, 1.1931566],
                [0.61328, 0.6148609, 0.7182701],
                [1.3955007, 1.5121545, 0.83443564],
                [0.97949594, 1.0675149, 0.6329687],
                [0.531663, 0.5955577, 1.7500148],
                [0.9624706, 0.72462326, 1.1481868],
                [0.83221924, 1.0490936, 1.6875663],
                [0.21132214, 0.4206159, 0.5372846],
                [1.4440073, 1.8970833, 0.26985747],
                [1.0294262, 1.4040797, 0.87554324],
                [1.3766412, 0.65521795, 1.6813129],
                [0.6650819, 0.71111923, 1.298853],
                [0.41999173, 0.37906948, 1.7513971],
                [0.59359556, 0.5912492, 0.73919016],
                [0.50867593, 0.50656086, 0.30136237],
                [1.1511526, 1.0546296, 0.49706793],
                [0.47535285, 0.49249494, 0.5802117],
            ],
        ),
    ),
    roi_head=dict(
        bbox_head=dict(
            num_classes=18,
            bbox_coder=dict(
                type='PartialBinBasedBBoxCoder',
                num_sizes=18,
                num_dir_bins=24,
                with_rot=False,
                # Same table as in `rpn_head.bbox_coder`.
                mean_sizes=[
                    [0.76966727, 0.8116021, 0.92573744],
                    [1.876858, 1.8425595, 1.1931566],
                    [0.61328, 0.6148609, 0.7182701],
                    [1.3955007, 1.5121545, 0.83443564],
                    [0.97949594, 1.0675149, 0.6329687],
                    [0.531663, 0.5955577, 1.7500148],
                    [0.9624706, 0.72462326, 1.1481868],
                    [0.83221924, 1.0490936, 1.6875663],
                    [0.21132214, 0.4206159, 0.5372846],
                    [1.4440073, 1.8970833, 0.26985747],
                    [1.0294262, 1.4040797, 0.87554324],
                    [1.3766412, 0.65521795, 1.6813129],
                    [0.6650819, 0.71111923, 1.298853],
                    [0.41999173, 0.37906948, 1.7513971],
                    [0.59359556, 0.5912492, 0.73919016],
                    [0.50867593, 0.50656086, 0.30136237],
                    [1.1511526, 1.0546296, 0.49706793],
                    [0.47535285, 0.49249494, 0.5802117],
                ],
            ),
        ),
    ),
)
# Training dataloader overrides: batch size and number of loader workers.
train_dataloader = dict(batch_size=3, num_workers=2)
# yapf:disable
# Override the logger hook so that its `interval` is 30.
default_hooks = dict(logger=dict(type='LoggerHook', interval=30))
# yapf:enable
...@@ -57,8 +57,13 @@ model = dict( ...@@ -57,8 +57,13 @@ model = dict(
[1.1511526, 1.0546296, 0.49706793], [1.1511526, 1.0546296, 0.49706793],
[0.47535285, 0.49249494, 0.5802117]])))) [0.47535285, 0.49249494, 0.5802117]]))))
data = dict(samples_per_gpu=3, workers_per_gpu=2) train_dataloader = dict(
batch_size=3,
num_workers=2,
)
# yapf:disable # yapf:disable
log_config = dict(interval=30) default_hooks = dict(
logger=dict(type='LoggerHook', interval=30)
)
# yapf:enable # yapf:enable
...@@ -5,11 +5,9 @@ from typing import Callable, List, Optional, Union ...@@ -5,11 +5,9 @@ from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
from mmdet3d.core import show_result
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.registry import DATASETS from mmdet3d.registry import DATASETS
from .det3d_dataset import Det3DDataset from .det3d_dataset import Det3DDataset
from .pipelines import Compose
from .seg3d_dataset import Seg3DDataset from .seg3d_dataset import Seg3DDataset
...@@ -151,46 +149,6 @@ class ScanNetDataset(Det3DDataset): ...@@ -151,46 +149,6 @@ class ScanNetDataset(Det3DDataset):
return ann_info return ann_info
def _build_default_pipeline(self):
"""Build the default pipeline for this dataset."""
pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='DefaultFormatBundle3D',
class_names=self.CLASSES,
with_label=False),
dict(type='Collect3D', keys=['points'])
]
return Compose(pipeline)
def show(self, results, out_dir, show=True, pipeline=None):
"""Results visualization.
Args:
results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result.
show (bool): Visualize the results online.
pipeline (list[dict], optional): raw data loading for showing.
Default: None.
"""
assert out_dir is not None, 'Expect out_dir, got none.'
pipeline = self._get_pipeline(pipeline)
for i, result in enumerate(results):
data_info = self.get_data_info[i]
pts_path = data_info['lidar_points']['lidar_path']
file_name = osp.split(pts_path)[-1].split('.')[0]
points = self._extract_data(i, pipeline, 'points').numpy()
gt_bboxes = self.get_ann_info(i)['gt_bboxes_3d'].tensor.numpy()
pred_bboxes = result['boxes_3d'].tensor.numpy()
show_result(points, gt_bboxes, pred_bboxes, out_dir, file_name,
show)
@DATASETS.register_module() @DATASETS.register_module()
class ScanNetSegDataset(Seg3DDataset): class ScanNetSegDataset(Seg3DDataset):
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from mmcv.ops import furthest_point_sample from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmengine import ConfigDict, InstanceData from mmengine import ConfigDict, InstanceData
from torch import Tensor
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core.post_processing import aligned_3d_nms from mmdet3d.core.post_processing import aligned_3d_nms
...@@ -161,7 +162,7 @@ class VoteHead(BaseModule): ...@@ -161,7 +162,7 @@ class VoteHead(BaseModule):
points: List[torch.Tensor], points: List[torch.Tensor],
feats_dict: Dict[str, torch.Tensor], feats_dict: Dict[str, torch.Tensor],
batch_data_samples: List[Det3DDataSample], batch_data_samples: List[Det3DDataSample],
rescale=True, use_nms: bool = True,
**kwargs) -> List[InstanceData]: **kwargs) -> List[InstanceData]:
""" """
Args: Args:
...@@ -169,8 +170,8 @@ class VoteHead(BaseModule): ...@@ -169,8 +170,8 @@ class VoteHead(BaseModule):
feats_dict (dict): Features from FPN or backbone.. feats_dict (dict): Features from FPN or backbone..
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes meta information of data. Samples. It usually includes meta information of data.
rescale (bool): Whether rescale the resutls to use_nms (bool): Whether do the nms for predictions.
the original scale. Defaults to True.
Returns: Returns:
list[:obj:`InstanceData`]: List of processed predictions. Each list[:obj:`InstanceData`]: List of processed predictions. Each
...@@ -178,6 +179,9 @@ class VoteHead(BaseModule): ...@@ -178,6 +179,9 @@ class VoteHead(BaseModule):
scores and labels. scores and labels.
""" """
preds_dict = self(feats_dict) preds_dict = self(feats_dict)
# `preds_dict` can be used in H3DNET
feats_dict.update(preds_dict)
batch_size = len(batch_data_samples) batch_size = len(batch_data_samples)
batch_input_metas = [] batch_input_metas = []
for batch_index in range(batch_size): for batch_index in range(batch_size):
...@@ -185,12 +189,73 @@ class VoteHead(BaseModule): ...@@ -185,12 +189,73 @@ class VoteHead(BaseModule):
batch_input_metas.append(metainfo) batch_input_metas.append(metainfo)
results_list = self.predict_by_feat( results_list = self.predict_by_feat(
points, preds_dict, batch_input_metas, rescale=rescale, **kwargs) points, preds_dict, batch_input_metas, use_nms=use_nms, **kwargs)
return results_list return results_list
def loss(self, points: List[torch.Tensor], feats_dict: Dict[str, def loss_and_predict(self,
torch.Tensor], points: List[torch.Tensor],
batch_data_samples: List[Det3DDataSample], **kwargs) -> dict: feats_dict: Dict[str, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
ret_target: bool = False,
proposal_cfg: dict = None,
**kwargs) -> Tuple:
"""
Args:
points (list[tensor]): Points cloud of multiple samples.
feats_dict (dict): Predictions from backbone or FPN.
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each sample and
corresponding annotations.
ret_target (bool): Whether return the assigned target.
Defaults to False.
proposal_cfg (dict): Config for the proposal process.
Defaults to None.
Returns:
tuple: Contains loss and predictions after post-process.
"""
preds_dict = self.forward(feats_dict)
feats_dict.update(preds_dict)
batch_gt_instance_3d = []
batch_gt_instances_ignore = []
batch_input_metas = []
batch_pts_semantic_mask = []
batch_pts_instance_mask = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
batch_gt_instance_3d.append(data_sample.gt_instances_3d)
batch_gt_instances_ignore.append(
data_sample.get('ignored_instances', None))
batch_pts_semantic_mask.append(
data_sample.gt_pts_seg.get('pts_semantic_mask', None))
batch_pts_instance_mask.append(
data_sample.gt_pts_seg.get('pts_instance_mask', None))
loss_inputs = (points, preds_dict, batch_gt_instance_3d)
losses = self.loss_by_feat(
*loss_inputs,
batch_pts_semantic_mask=batch_pts_semantic_mask,
batch_pts_instance_mask=batch_pts_instance_mask,
batch_input_metas=batch_input_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore,
ret_target=ret_target,
**kwargs)
results_list = self.predict_by_feat(
points,
preds_dict,
batch_input_metas,
use_nms=proposal_cfg.use_nms,
**kwargs)
return losses, results_list
def loss(self,
points: List[torch.Tensor],
feats_dict: Dict[str, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
ret_target: bool = False,
**kwargs) -> dict:
""" """
Args: Args:
points (list[tensor]): Points cloud of multiple samples. points (list[tensor]): Points cloud of multiple samples.
...@@ -198,6 +263,8 @@ class VoteHead(BaseModule): ...@@ -198,6 +263,8 @@ class VoteHead(BaseModule):
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each sample and contains the meta information of each sample and
corresponding annotations. corresponding annotations.
ret_target (bool): Whether return the assigned target.
Defaults to False.
Returns: Returns:
dict: A dictionary of loss components. dict: A dictionary of loss components.
...@@ -224,7 +291,9 @@ class VoteHead(BaseModule): ...@@ -224,7 +291,9 @@ class VoteHead(BaseModule):
batch_pts_semantic_mask=batch_pts_semantic_mask, batch_pts_semantic_mask=batch_pts_semantic_mask,
batch_pts_instance_mask=batch_pts_instance_mask, batch_pts_instance_mask=batch_pts_instance_mask,
batch_input_metas=batch_input_metas, batch_input_metas=batch_input_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore) batch_gt_instances_ignore=batch_gt_instances_ignore,
ret_target=ret_target,
**kwargs)
return losses return losses
def forward(self, feat_dict: dict) -> dict: def forward(self, feat_dict: dict) -> dict:
...@@ -330,7 +399,7 @@ class VoteHead(BaseModule): ...@@ -330,7 +399,7 @@ class VoteHead(BaseModule):
batch_pts_semantic_mask (list[tensor]): Instance mask batch_pts_semantic_mask (list[tensor]): Instance mask
of points cloud. Defaults to None. of points cloud. Defaults to None.
batch_input_metas (list[dict]): Contain pcd and img's meta info. batch_input_metas (list[dict]): Contain pcd and img's meta info.
ret_target (bool): Return targets or not. ret_target (bool): Return targets or not. Defaults to False.
Returns: Returns:
dict: Losses of Votenet. dict: Losses of Votenet.
...@@ -671,9 +740,10 @@ class VoteHead(BaseModule): ...@@ -671,9 +740,10 @@ class VoteHead(BaseModule):
while using vote head in rpn stage. while using vote head in rpn stage.
Returns: Returns:
list[:obj:`InstanceData`]: List of processed predictions. Each list[:obj:`InstanceData`] or Tensor: Return list of processed
InstanceData cantains 3d Bounding boxes and corresponding predictions when `use_nms` is True. Each InstanceData cantains
scores and labels. 3d Bounding boxes and corresponding scores and labels.
Return raw bboxes when `use_nms` is False.
""" """
# decode boxes # decode boxes
stack_points = torch.stack(points) stack_points = torch.stack(points)
...@@ -683,9 +753,9 @@ class VoteHead(BaseModule): ...@@ -683,9 +753,9 @@ class VoteHead(BaseModule):
batch_size = bbox3d.shape[0] batch_size = bbox3d.shape[0]
results_list = list() results_list = list()
for b in range(batch_size): if use_nms:
temp_results = InstanceData() for b in range(batch_size):
if use_nms: temp_results = InstanceData()
bbox_selected, score_selected, labels = \ bbox_selected, score_selected, labels = \
self.multiclass_nms_single(obj_scores[b], self.multiclass_nms_single(obj_scores[b],
sem_scores[b], sem_scores[b],
...@@ -700,20 +770,15 @@ class VoteHead(BaseModule): ...@@ -700,20 +770,15 @@ class VoteHead(BaseModule):
temp_results.scores_3d = score_selected temp_results.scores_3d = score_selected
temp_results.labels_3d = labels temp_results.labels_3d = labels
results_list.append(temp_results) results_list.append(temp_results)
else:
bbox = batch_input_metas[b]['box_type_3d'](
bbox_selected,
box_dim=bbox_selected.shape[-1],
with_yaw=self.bbox_coder.with_rot)
temp_results.bboxes_3d = bbox
temp_results.obj_scores_3d = obj_scores[b]
temp_results.sem_scores_3d = obj_scores[b]
results_list.append(temp_results)
return results_list return results_list
else:
# TODO unify it when refactor the Augtest
return bbox3d
def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points, def multiclass_nms_single(self, obj_scores: Tensor, sem_scores: Tensor,
input_meta): bbox: Tensor, points: Tensor,
input_meta: dict) -> Tuple:
"""Multi-class nms in single batch. """Multi-class nms in single batch.
Args: Args:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch import torch
from torch import Tensor
from mmdet3d.core import merge_aug_bboxes_3d
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from ...core import Det3DDataSample
from .two_stage import TwoStage3DDetector from .two_stage import TwoStage3DDetector
...@@ -11,17 +14,33 @@ class H3DNet(TwoStage3DDetector): ...@@ -11,17 +14,33 @@ class H3DNet(TwoStage3DDetector):
r"""H3DNet model. r"""H3DNet model.
Please refer to the `paper <https://arxiv.org/abs/2006.05682>`_ Please refer to the `paper <https://arxiv.org/abs/2006.05682>`_
Args:
backbone (dict): Config dict of detector's backbone.
neck (dict, optional): Config dict of neck. Defaults to None.
rpn_head (dict, optional): Config dict of rpn head. Defaults to None.
roi_head (dict, optional): Config dict of roi head. Defaults to None.
train_cfg (dict, optional): Config dict of training hyper-parameters.
Defaults to None.
test_cfg (dict, optional): Config dict of test hyper-parameters.
Defaults to None.
init_cfg (dict, optional): The config to control the
initialization. Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`BaseDataPreprocessor`. It usually includes
``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
Defaults to None.
""" """
def __init__(self, def __init__(self,
backbone, backbone: dict,
neck=None, neck: Optional[dict] = None,
rpn_head=None, rpn_head: Optional[dict] = None,
roi_head=None, roi_head: Optional[dict] = None,
train_cfg=None, train_cfg: Optional[dict] = None,
test_cfg=None, test_cfg: Optional[dict] = None,
pretrained=None, init_cfg: Optional[dict] = None,
init_cfg=None): data_preprocessor: Optional[dict] = None,
**kwargs) -> None:
super(H3DNet, self).__init__( super(H3DNet, self).__init__(
backbone=backbone, backbone=backbone,
neck=neck, neck=neck,
...@@ -29,148 +48,110 @@ class H3DNet(TwoStage3DDetector): ...@@ -29,148 +48,110 @@ class H3DNet(TwoStage3DDetector):
roi_head=roi_head, roi_head=roi_head,
train_cfg=train_cfg, train_cfg=train_cfg,
test_cfg=test_cfg, test_cfg=test_cfg,
pretrained=pretrained, init_cfg=init_cfg,
init_cfg=init_cfg) data_preprocessor=data_preprocessor,
**kwargs)
def forward_train(self,
points, def extract_feat(self, batch_inputs_dict: dict) -> None:
img_metas, """Directly extract features from the backbone+neck.
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
gt_bboxes_ignore=None):
"""Forward of training.
Args: Args:
points (list[torch.Tensor]): Points of each batch.
img_metas (list): Image metas. batch_inputs_dict (dict): The model input dict which include
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch. 'points'.
gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
pts_semantic_mask (list[torch.Tensor]): point-wise semantic - points (list[torch.Tensor]): Point cloud of each sample.
label of each batch.
pts_instance_mask (list[torch.Tensor]): point-wise instance
label of each batch.
gt_bboxes_ignore (list[torch.Tensor]): Specify
which bounding.
Returns: Returns:
dict: Losses. dict: Dict of feature.
""" """
points_cat = torch.stack(points) stack_points = torch.stack(batch_inputs_dict['points'])
x = self.backbone(stack_points)
if self.with_neck:
x = self.neck(x)
return x
def loss(self, batch_inputs_dict: Dict[str, Union[List, Tensor]],
batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
"""
Args:
batch_inputs_dict (dict): The model input dict which include
'points' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instances_3d`.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
feats_dict = self.extract_feat(batch_inputs_dict)
feats_dict = self.extract_feat(points_cat)
feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]] feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
feats_dict['fp_features'] = [feats_dict['hd_feature']] feats_dict['fp_features'] = [feats_dict['hd_feature']]
feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]] feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
losses = dict() losses = dict()
if self.with_rpn: if self.with_rpn:
rpn_outs = self.rpn_head(feats_dict, self.train_cfg.rpn.sample_mod)
feats_dict.update(rpn_outs)
rpn_loss_inputs = (points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask, img_metas)
rpn_losses = self.rpn_head.loss(
rpn_outs,
*rpn_loss_inputs,
gt_bboxes_ignore=gt_bboxes_ignore,
ret_target=True)
feats_dict['targets'] = rpn_losses.pop('targets')
losses.update(rpn_losses)
# Generate rpn proposals
proposal_cfg = self.train_cfg.get('rpn_proposal', proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn) self.test_cfg.rpn)
proposal_inputs = (points, rpn_outs, img_metas) # note, the feats_dict would be added new key & value in rpn_head
proposal_list = self.rpn_head.get_bboxes( rpn_losses, rpn_proposals = self.rpn_head.loss_and_predict(
*proposal_inputs, use_nms=proposal_cfg.use_nms) batch_inputs_dict['points'],
feats_dict['proposal_list'] = proposal_list feats_dict,
batch_data_samples,
ret_target=True,
proposal_cfg=proposal_cfg)
feats_dict['targets'] = rpn_losses.pop('targets')
losses.update(rpn_losses)
feats_dict['rpn_proposals'] = rpn_proposals
else: else:
raise NotImplementedError raise NotImplementedError
roi_losses = self.roi_head.forward_train(feats_dict, img_metas, points, roi_losses = self.roi_head.loss(batch_inputs_dict['points'],
gt_bboxes_3d, gt_labels_3d, feats_dict, batch_data_samples,
pts_semantic_mask, **kwargs)
pts_instance_mask,
gt_bboxes_ignore)
losses.update(roi_losses) losses.update(roi_losses)
return losses return losses
def simple_test(self, points, img_metas, imgs=None, rescale=False): def predict(
"""Forward of testing. self, batch_input_dict: Dict,
batch_data_samples: List[Det3DDataSample]
) -> List[Det3DDataSample]:
"""Get model predictions.
Args: Args:
points (list[torch.Tensor]): Points of each sample. points (list[torch.Tensor]): Points of each sample.
img_metas (list): Image metas. batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
rescale (bool): Whether to rescale results. contains the meta information of each sample and
corresponding annotations.
Returns: Returns:
list: Predicted 3d boxes. list: Predicted 3d boxes.
""" """
points_cat = torch.stack(points)
feats_dict = self.extract_feat(points_cat) feats_dict = self.extract_feat(batch_input_dict)
feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]] feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
feats_dict['fp_features'] = [feats_dict['hd_feature']] feats_dict['fp_features'] = [feats_dict['hd_feature']]
feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]] feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
if self.with_rpn: if self.with_rpn:
proposal_cfg = self.test_cfg.rpn proposal_cfg = self.test_cfg.rpn
rpn_outs = self.rpn_head(feats_dict, proposal_cfg.sample_mod) rpn_proposals = self.rpn_head.predict(
feats_dict.update(rpn_outs) batch_input_dict['points'],
# Generate rpn proposals feats_dict,
proposal_list = self.rpn_head.get_bboxes( batch_data_samples,
points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms) use_nms=proposal_cfg.use_nms)
feats_dict['proposal_list'] = proposal_list feats_dict['rpn_proposals'] = rpn_proposals
else: else:
raise NotImplementedError raise NotImplementedError
return self.roi_head.simple_test( results_list = self.roi_head.predict(
feats_dict, img_metas, points_cat, rescale=rescale) batch_input_dict['points'],
feats_dict,
def aug_test(self, points, img_metas, imgs=None, rescale=False): batch_data_samples,
"""Test with augmentation.""" suffix='_optimized')
points_cat = [torch.stack(pts) for pts in points] return self.convert_to_datasample(results_list)
feats_dict = self.extract_feats(points_cat, img_metas)
for feat_dict in feats_dict:
feat_dict['fp_xyz'] = [feat_dict['fp_xyz_net0'][-1]]
feat_dict['fp_features'] = [feat_dict['hd_feature']]
feat_dict['fp_indices'] = [feat_dict['fp_indices_net0'][-1]]
# only support aug_test for one sample
aug_bboxes = []
for feat_dict, pts_cat, img_meta in zip(feats_dict, points_cat,
img_metas):
if self.with_rpn:
proposal_cfg = self.test_cfg.rpn
rpn_outs = self.rpn_head(feat_dict, proposal_cfg.sample_mod)
feat_dict.update(rpn_outs)
# Generate rpn proposals
proposal_list = self.rpn_head.get_bboxes(
points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms)
feat_dict['proposal_list'] = proposal_list
else:
raise NotImplementedError
bbox_results = self.roi_head.simple_test(
feat_dict,
self.test_cfg.rcnn.sample_mod,
img_meta,
pts_cat,
rescale=rescale)
aug_bboxes.append(bbox_results)
# after merging, bboxes will be rescaled to the original image size
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.bbox_head.test_cfg)
return [merged_bboxes]
def extract_feats(self, points, img_metas):
"""Extract features of multiple samples."""
return [
self.extract_feat(pts, img_meta)
for pts, img_meta in zip(points, img_metas)
]
...@@ -56,12 +56,12 @@ class PointRCNN(TwoStage3DDetector): ...@@ -56,12 +56,12 @@ class PointRCNN(TwoStage3DDetector):
x = self.neck(x) x = self.neck(x)
return x return x
def forward_train(self, points, img_metas, gt_bboxes_3d, gt_labels_3d): def forward_train(self, points, input_metas, gt_bboxes_3d, gt_labels_3d):
"""Forward of training. """Forward of training.
Args: Args:
points (list[torch.Tensor]): Points of each batch. points (list[torch.Tensor]): Points of each batch.
img_metas (list[dict]): Meta information of each sample. input_metas (list[dict]): Meta information of each sample.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch. gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): gt class labels of each batch. gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
...@@ -69,8 +69,8 @@ class PointRCNN(TwoStage3DDetector): ...@@ -69,8 +69,8 @@ class PointRCNN(TwoStage3DDetector):
dict: Losses. dict: Losses.
""" """
losses = dict() losses = dict()
points_cat = torch.stack(points) stack_points = torch.stack(points)
x = self.extract_feat(points_cat) x = self.extract_feat(stack_points)
# features for rcnn # features for rcnn
backbone_feats = x['fp_features'].clone() backbone_feats = x['fp_features'].clone()
...@@ -85,11 +85,11 @@ class PointRCNN(TwoStage3DDetector): ...@@ -85,11 +85,11 @@ class PointRCNN(TwoStage3DDetector):
points=points, points=points,
gt_bboxes_3d=gt_bboxes_3d, gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=gt_labels_3d, gt_labels_3d=gt_labels_3d,
img_metas=img_metas) input_metas=input_metas)
losses.update(rpn_loss) losses.update(rpn_loss)
bbox_list = self.rpn_head.get_bboxes(points_cat, bbox_preds, cls_preds, bbox_list = self.rpn_head.get_bboxes(stack_points, bbox_preds,
img_metas) cls_preds, input_metas)
proposal_list = [ proposal_list = [
dict( dict(
boxes_3d=bboxes, boxes_3d=bboxes,
...@@ -100,7 +100,7 @@ class PointRCNN(TwoStage3DDetector): ...@@ -100,7 +100,7 @@ class PointRCNN(TwoStage3DDetector):
] ]
rcnn_feats.update({'points_cls_preds': cls_preds}) rcnn_feats.update({'points_cls_preds': cls_preds})
roi_losses = self.roi_head.forward_train(rcnn_feats, img_metas, roi_losses = self.roi_head.forward_train(rcnn_feats, input_metas,
proposal_list, gt_bboxes_3d, proposal_list, gt_bboxes_3d,
gt_labels_3d) gt_labels_3d)
losses.update(roi_losses) losses.update(roi_losses)
...@@ -121,9 +121,9 @@ class PointRCNN(TwoStage3DDetector): ...@@ -121,9 +121,9 @@ class PointRCNN(TwoStage3DDetector):
Returns: Returns:
list: Predicted 3d boxes. list: Predicted 3d boxes.
""" """
points_cat = torch.stack(points) stack_points = torch.stack(points)
x = self.extract_feat(points_cat) x = self.extract_feat(stack_points)
# features for rcnn # features for rcnn
backbone_feats = x['fp_features'].clone() backbone_feats = x['fp_features'].clone()
backbone_xyz = x['fp_xyz'].clone() backbone_xyz = x['fp_xyz'].clone()
...@@ -132,7 +132,7 @@ class PointRCNN(TwoStage3DDetector): ...@@ -132,7 +132,7 @@ class PointRCNN(TwoStage3DDetector):
rcnn_feats.update({'points_cls_preds': cls_preds}) rcnn_feats.update({'points_cls_preds': cls_preds})
bbox_list = self.rpn_head.get_bboxes( bbox_list = self.rpn_head.get_bboxes(
points_cat, bbox_preds, cls_preds, img_metas, rescale=rescale) stack_points, bbox_preds, cls_preds, img_metas, rescale=rescale)
proposal_list = [ proposal_list = [
dict( dict(
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.core.bbox import bbox3d2result from typing import Dict, List
from mmengine import InstanceData
from torch import Tensor
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from ...core import Det3DDataSample
from .base_3droi_head import Base3DRoIHead from .base_3droi_head import Base3DRoIHead
...@@ -16,17 +21,15 @@ class H3DRoIHead(Base3DRoIHead): ...@@ -16,17 +21,15 @@ class H3DRoIHead(Base3DRoIHead):
""" """
def __init__(self, def __init__(self,
primitive_list, primitive_list: List[dict],
bbox_head=None, bbox_head: dict = None,
train_cfg=None, train_cfg: dict = None,
test_cfg=None, test_cfg: dict = None,
pretrained=None, init_cfg: dict = None):
init_cfg=None):
super(H3DRoIHead, self).__init__( super(H3DRoIHead, self).__init__(
bbox_head=bbox_head, bbox_head=bbox_head,
train_cfg=train_cfg, train_cfg=train_cfg,
test_cfg=test_cfg, test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg) init_cfg=init_cfg)
# Primitive module # Primitive module
assert len(primitive_list) == 3 assert len(primitive_list) == 3
...@@ -39,8 +42,14 @@ class H3DRoIHead(Base3DRoIHead): ...@@ -39,8 +42,14 @@ class H3DRoIHead(Base3DRoIHead):
one.""" one."""
pass pass
def init_bbox_head(self, bbox_head): def init_bbox_head(self, dummy_args, bbox_head):
"""Initialize box head.""" """Initialize box head.
Args:
dummy_args (optional): Only present to stay compatible with
the interface of the base class.
bbox_head (dict): Config for bbox head.
"""
bbox_head['train_cfg'] = self.train_cfg bbox_head['train_cfg'] = self.train_cfg
bbox_head['test_cfg'] = self.test_cfg bbox_head['test_cfg'] = self.test_cfg
self.bbox_head = MODELS.build(bbox_head) self.bbox_head = MODELS.build(bbox_head)
...@@ -49,111 +58,73 @@ class H3DRoIHead(Base3DRoIHead): ...@@ -49,111 +58,73 @@ class H3DRoIHead(Base3DRoIHead):
"""Initialize assigner and sampler.""" """Initialize assigner and sampler."""
pass pass
def forward_train(self, def loss(self, points: List[Tensor], feats_dict: dict,
feats_dict, batch_data_samples: List[Det3DDataSample], **kwargs):
img_metas,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask,
pts_instance_mask,
gt_bboxes_ignore=None):
"""Training forward function of PartAggregationROIHead. """Training forward function of PartAggregationROIHead.
Args: Args:
feats_dict (dict): Contains features from the first stage. points (list[torch.Tensor]): Point cloud of each sample.
img_metas (list[dict]): Contain pcd and img's meta info. feats_dict (dict): Dict of feature.
points (list[torch.Tensor]): Input points. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth Samples. It usually includes information such as
bboxes of each sample. `gt_instance_3d`.
gt_labels_3d (list[torch.Tensor]): Labels of each sample.
pts_semantic_mask (list[torch.Tensor]): Point-wise
semantic mask.
pts_instance_mask (list[torch.Tensor]): Point-wise
instance mask.
gt_bboxes_ignore (list[torch.Tensor]): Specify
which bounding boxes to ignore.
Returns: Returns:
dict: losses from each head. dict: losses from each head.
""" """
losses = dict() losses = dict()
sample_mod = self.train_cfg.sample_mod primitive_loss_inputs = (points, feats_dict, batch_data_samples)
assert sample_mod in ['vote', 'seed', 'random'] # note the feats_dict would be added new key and value in each head.
result_z = self.primitive_z(feats_dict, sample_mod)
feats_dict.update(result_z)
result_xy = self.primitive_xy(feats_dict, sample_mod)
feats_dict.update(result_xy)
result_line = self.primitive_line(feats_dict, sample_mod)
feats_dict.update(result_line)
primitive_loss_inputs = (feats_dict, points, gt_bboxes_3d,
gt_labels_3d, pts_semantic_mask,
pts_instance_mask, img_metas,
gt_bboxes_ignore)
loss_z = self.primitive_z.loss(*primitive_loss_inputs) loss_z = self.primitive_z.loss(*primitive_loss_inputs)
losses.update(loss_z)
loss_xy = self.primitive_xy.loss(*primitive_loss_inputs) loss_xy = self.primitive_xy.loss(*primitive_loss_inputs)
losses.update(loss_xy)
loss_line = self.primitive_line.loss(*primitive_loss_inputs) loss_line = self.primitive_line.loss(*primitive_loss_inputs)
losses.update(loss_z)
losses.update(loss_xy)
losses.update(loss_line) losses.update(loss_line)
targets = feats_dict.pop('targets') targets = feats_dict.pop('targets')
bbox_results = self.bbox_head(feats_dict, sample_mod) bbox_loss = self.bbox_head.loss(
points,
feats_dict.update(bbox_results) feats_dict,
bbox_loss = self.bbox_head.loss(feats_dict, points, gt_bboxes_3d, rpn_targets=targets,
gt_labels_3d, pts_semantic_mask, batch_data_samples=batch_data_samples)
pts_instance_mask, img_metas, targets,
gt_bboxes_ignore)
losses.update(bbox_loss) losses.update(bbox_loss)
return losses return losses
def simple_test(self, feats_dict, img_metas, points, rescale=False): def predict(self,
"""Simple testing forward function of PartAggregationROIHead. points: List[Tensor],
feats_dict: Dict[str, Tensor],
Note: batch_data_samples: List[Det3DDataSample],
This function assumes that the batch size is 1 suffix='_optimized',
**kwargs) -> List[InstanceData]:
"""
Args: Args:
feats_dict (dict): Contains features from the first stage. points (list[tensor]): Point clouds of multiple samples.
img_metas (list[dict]): Contain pcd and img's meta info. feats_dict (dict): Features from FPN or backbone.
points (torch.Tensor): Input points. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
rescale (bool): Whether to rescale results. Samples. It usually includes meta information of data.
Returns: Returns:
dict: Bbox results of one frame. list[:obj:`InstanceData`]: List of processed predictions. Each
InstanceData contains 3d Bounding boxes and corresponding
scores and labels.
""" """
sample_mod = self.test_cfg.sample_mod
assert sample_mod in ['vote', 'seed', 'random']
result_z = self.primitive_z(feats_dict, sample_mod) result_z = self.primitive_z(feats_dict)
feats_dict.update(result_z) feats_dict.update(result_z)
result_xy = self.primitive_xy(feats_dict, sample_mod) result_xy = self.primitive_xy(feats_dict)
feats_dict.update(result_xy) feats_dict.update(result_xy)
result_line = self.primitive_line(feats_dict, sample_mod) result_line = self.primitive_line(feats_dict)
feats_dict.update(result_line) feats_dict.update(result_line)
bbox_preds = self.bbox_head(feats_dict, sample_mod) bbox_preds = self.bbox_head(feats_dict)
feats_dict.update(bbox_preds) feats_dict.update(bbox_preds)
bbox_list = self.bbox_head.get_bboxes( results_list = self.bbox_head.predict(
points, points, feats_dict, batch_data_samples, suffix=suffix)
feats_dict,
img_metas, return results_list
rescale=rescale,
suffix='_optimized')
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.ops import furthest_point_sample from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmengine import InstanceData
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.models.builder import build_loss from mmdet3d.core import Det3DDataSample
from mmdet3d.models.model_utils import VoteModule from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import build_sa_module from mmdet3d.ops import build_sa_module
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
...@@ -40,24 +43,25 @@ class PrimitiveHead(BaseModule): ...@@ -40,24 +43,25 @@ class PrimitiveHead(BaseModule):
""" """
def __init__(self, def __init__(self,
num_dims, num_dims: int,
num_classes, num_classes: int,
primitive_mode, primitive_mode: str,
train_cfg=None, train_cfg: dict = None,
test_cfg=None, test_cfg: dict = None,
vote_module_cfg=None, vote_module_cfg: dict = None,
vote_aggregation_cfg=None, vote_aggregation_cfg: dict = None,
feat_channels=(128, 128), feat_channels: tuple = (128, 128),
upper_thresh=100.0, upper_thresh: float = 100.0,
surface_thresh=0.5, surface_thresh: float = 0.5,
conv_cfg=dict(type='Conv1d'), conv_cfg: dict = dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'), norm_cfg: dict = dict(type='BN1d'),
objectness_loss=None, objectness_loss: dict = None,
center_loss=None, center_loss: dict = None,
semantic_reg_loss=None, semantic_reg_loss: dict = None,
semantic_cls_loss=None, semantic_cls_loss: dict = None,
init_cfg=None): init_cfg: dict = None):
super(PrimitiveHead, self).__init__(init_cfg=init_cfg) super(PrimitiveHead, self).__init__(init_cfg=init_cfg)
# bounding boxes centers, face centers and edge centers
assert primitive_mode in ['z', 'xy', 'line'] assert primitive_mode in ['z', 'xy', 'line']
# The dimension of primitive semantic information. # The dimension of primitive semantic information.
self.num_dims = num_dims self.num_dims = num_dims
...@@ -70,10 +74,10 @@ class PrimitiveHead(BaseModule): ...@@ -70,10 +74,10 @@ class PrimitiveHead(BaseModule):
self.upper_thresh = upper_thresh self.upper_thresh = upper_thresh
self.surface_thresh = surface_thresh self.surface_thresh = surface_thresh
self.objectness_loss = build_loss(objectness_loss) self.loss_objectness = MODELS.build(objectness_loss)
self.center_loss = build_loss(center_loss) self.loss_center = MODELS.build(center_loss)
self.semantic_reg_loss = build_loss(semantic_reg_loss) self.loss_semantic_reg = MODELS.build(semantic_reg_loss)
self.semantic_cls_loss = build_loss(semantic_cls_loss) self.loss_semantic_cls = MODELS.build(semantic_cls_loss)
assert vote_aggregation_cfg['mlp_channels'][0] == vote_module_cfg[ assert vote_aggregation_cfg['mlp_channels'][0] == vote_module_cfg[
'in_channels'] 'in_channels']
...@@ -114,18 +118,26 @@ class PrimitiveHead(BaseModule): ...@@ -114,18 +118,26 @@ class PrimitiveHead(BaseModule):
self.conv_pred.add_module('conv_out', self.conv_pred.add_module('conv_out',
nn.Conv1d(prev_channel, conv_out_channel, 1)) nn.Conv1d(prev_channel, conv_out_channel, 1))
def forward(self, feats_dict, sample_mod): @property
def sample_mode(self):
if self.training:
sample_mode = self.train_cfg.sample_mode
else:
sample_mode = self.test_cfg.sample_mode
assert sample_mode in ['vote', 'seed', 'random']
return sample_mode
def forward(self, feats_dict):
"""Forward pass. """Forward pass.
Args: Args:
feats_dict (dict): Feature dict from backbone. feats_dict (dict): Feature dict from backbone.
sample_mod (str): Sample mode for vote aggregation layer.
valid modes are "vote", "seed" and "random".
Returns: Returns:
dict: Predictions of primitive head. dict: Predictions of primitive head.
""" """
assert sample_mod in ['vote', 'seed', 'random'] sample_mode = self.sample_mode
seed_points = feats_dict['fp_xyz_net0'][-1] seed_points = feats_dict['fp_xyz_net0'][-1]
seed_features = feats_dict['hd_feature'] seed_features = feats_dict['hd_feature']
...@@ -143,14 +155,14 @@ class PrimitiveHead(BaseModule): ...@@ -143,14 +155,14 @@ class PrimitiveHead(BaseModule):
results['vote_features_' + self.primitive_mode] = vote_features results['vote_features_' + self.primitive_mode] = vote_features
# 2. aggregate vote_points # 2. aggregate vote_points
if sample_mod == 'vote': if sample_mode == 'vote':
# use fps in vote_aggregation # use fps in vote_aggregation
sample_indices = None sample_indices = None
elif sample_mod == 'seed': elif sample_mode == 'seed':
# FPS on seed and choose the votes corresponding to the seeds # FPS on seed and choose the votes corresponding to the seeds
sample_indices = furthest_point_sample(seed_points, sample_indices = furthest_point_sample(seed_points,
self.num_proposal) self.num_proposal)
elif sample_mod == 'random': elif sample_mode == 'random':
# Random sampling from the votes # Random sampling from the votes
batch_size, num_seed = seed_points.shape[:2] batch_size, num_seed = seed_points.shape[:2]
sample_indices = torch.randint( sample_indices = torch.randint(
...@@ -185,63 +197,103 @@ class PrimitiveHead(BaseModule): ...@@ -185,63 +197,103 @@ class PrimitiveHead(BaseModule):
results['pred_' + self.primitive_mode + '_center'] = center results['pred_' + self.primitive_mode + '_center'] = center
return results return results
def loss(self, def loss(self, points: List[torch.Tensor], feats_dict: Dict[str,
bbox_preds, torch.Tensor],
points, batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
gt_bboxes_3d, """
gt_labels_3d, Args:
pts_semantic_mask=None, points (list[tensor]): Points cloud of multiple samples.
pts_instance_mask=None, feats_dict (dict): Predictions from backbone or FPN.
img_metas=None, batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
gt_bboxes_ignore=None): contains the meta information of each sample and
corresponding annotations.
Returns:
dict: A dictionary of loss components.
"""
preds = self(feats_dict)
feats_dict.update(preds)
batch_gt_instance_3d = []
batch_gt_instances_ignore = []
batch_input_metas = []
batch_pts_semantic_mask = []
batch_pts_instance_mask = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
batch_gt_instance_3d.append(data_sample.gt_instances_3d)
batch_gt_instances_ignore.append(
data_sample.get('ignored_instances', None))
batch_pts_semantic_mask.append(
data_sample.gt_pts_seg.get('pts_semantic_mask', None))
batch_pts_instance_mask.append(
data_sample.gt_pts_seg.get('pts_instance_mask', None))
loss_inputs = (points, feats_dict, batch_gt_instance_3d)
losses = self.loss_by_feat(
*loss_inputs,
batch_pts_semantic_mask=batch_pts_semantic_mask,
batch_pts_instance_mask=batch_pts_instance_mask,
batch_gt_instances_ignore=batch_gt_instances_ignore,
)
return losses
def loss_by_feat(
self,
points: List[torch.Tensor],
feats_dict: dict,
batch_gt_instances_3d: List[InstanceData],
batch_pts_semantic_mask: Optional[List[torch.Tensor]] = None,
batch_pts_instance_mask: Optional[List[torch.Tensor]] = None,
**kwargs):
"""Compute loss. """Compute loss.
Args: Args:
bbox_preds (dict): Predictions from forward of primitive head.
points (list[torch.Tensor]): Input points. points (list[torch.Tensor]): Input points.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth feats_dict (dict): Predictions of previous modules.
bboxes of each sample. batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_labels_3d (list[torch.Tensor]): Labels of each sample. gt_instances. It usually includes ``bboxes`` and ``labels``
pts_semantic_mask (list[torch.Tensor]): Point-wise attributes.
semantic mask. batch_pts_semantic_mask (list[tensor]): Semantic mask
pts_instance_mask (list[torch.Tensor]): Point-wise of points cloud. Defaults to None.
instance mask. batch_pts_instance_mask (list[tensor]): Instance mask
img_metas (list[dict]): Contain pcd and img's meta info. of points cloud. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor]): Specify batch_input_metas (list[dict]): Contain pcd and img's meta info.
which bounding. ret_target (bool): Return targets or not. Defaults to False.
Returns: Returns:
dict: Losses of Primitive Head. dict: Losses of Primitive Head.
""" """
targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask, targets = self.get_targets(points, feats_dict, batch_gt_instances_3d,
bbox_preds) batch_pts_semantic_mask,
batch_pts_instance_mask)
(point_mask, point_offset, gt_primitive_center, gt_primitive_semantic, (point_mask, point_offset, gt_primitive_center, gt_primitive_semantic,
gt_sem_cls_label, gt_primitive_mask) = targets gt_sem_cls_label, gt_primitive_mask) = targets
losses = {} losses = {}
# Compute the loss of primitive existence flag # Compute the loss of primitive existence flag
pred_flag = bbox_preds['pred_flag_' + self.primitive_mode] pred_flag = feats_dict['pred_flag_' + self.primitive_mode]
flag_loss = self.objectness_loss(pred_flag, gt_primitive_mask.long()) flag_loss = self.loss_objectness(pred_flag, gt_primitive_mask.long())
losses['flag_loss_' + self.primitive_mode] = flag_loss losses['flag_loss_' + self.primitive_mode] = flag_loss
# calculate vote loss # calculate vote loss
vote_loss = self.vote_module.get_loss( vote_loss = self.vote_module.get_loss(
bbox_preds['seed_points'], feats_dict['seed_points'],
bbox_preds['vote_' + self.primitive_mode], feats_dict['vote_' + self.primitive_mode],
bbox_preds['seed_indices'], point_mask, point_offset) feats_dict['seed_indices'], point_mask, point_offset)
losses['vote_loss_' + self.primitive_mode] = vote_loss losses['vote_loss_' + self.primitive_mode] = vote_loss
num_proposal = bbox_preds['aggregated_points_' + num_proposal = feats_dict['aggregated_points_' +
self.primitive_mode].shape[1] self.primitive_mode].shape[1]
primitive_center = bbox_preds['center_' + self.primitive_mode] primitive_center = feats_dict['center_' + self.primitive_mode]
if self.primitive_mode != 'line': if self.primitive_mode != 'line':
primitive_semantic = bbox_preds['size_residuals_' + primitive_semantic = feats_dict['size_residuals_' +
self.primitive_mode].contiguous() self.primitive_mode].contiguous()
else: else:
primitive_semantic = None primitive_semantic = None
semancitc_scores = bbox_preds['sem_cls_scores_' + semancitc_scores = feats_dict['sem_cls_scores_' +
self.primitive_mode].transpose(2, 1) self.primitive_mode].transpose(2, 1)
gt_primitive_mask = gt_primitive_mask / \ gt_primitive_mask = gt_primitive_mask / \
...@@ -256,44 +308,61 @@ class PrimitiveHead(BaseModule): ...@@ -256,44 +308,61 @@ class PrimitiveHead(BaseModule):
return losses return losses
def get_targets(self, def get_targets(
points, self,
gt_bboxes_3d, points,
gt_labels_3d, bbox_preds: Optional[dict] = None,
pts_semantic_mask=None, batch_gt_instances_3d: List[InstanceData] = None,
pts_instance_mask=None, batch_pts_semantic_mask: List[torch.Tensor] = None,
bbox_preds=None): batch_pts_instance_mask: List[torch.Tensor] = None,
):
"""Generate targets of primitive head. """Generate targets of primitive head.
Args: Args:
points (list[torch.Tensor]): Points of each batch. points (list[torch.Tensor]): Points of each batch.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth bbox_preds (torch.Tensor): Bounding box predictions of
bboxes of each batch. primitive head.
gt_labels_3d (list[torch.Tensor]): Labels of each batch. batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
pts_semantic_mask (list[torch.Tensor]): Point-wise semantic gt_instances. It usually includes ``bboxes_3d`` and
label of each batch. ``labels_3d`` attributes.
pts_instance_mask (list[torch.Tensor]): Point-wise instance batch_pts_semantic_mask (list[tensor]): Semantic gt mask for
label of each batch. multiple images.
bbox_preds (dict): Predictions from forward of primitive head. batch_pts_instance_mask (list[tensor]): Instance gt mask for
multiple images.
Returns: Returns:
tuple[torch.Tensor]: Targets of primitive head. tuple[torch.Tensor]: Targets of primitive head.
""" """
for index in range(len(gt_labels_3d)): batch_gt_labels_3d = [
if len(gt_labels_3d[index]) == 0: gt_instances_3d.labels_3d
fake_box = gt_bboxes_3d[index].tensor.new_zeros( for gt_instances_3d in batch_gt_instances_3d
1, gt_bboxes_3d[index].tensor.shape[-1]) ]
gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box) batch_gt_bboxes_3d = [
gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1) gt_instances_3d.bboxes_3d
for gt_instances_3d in batch_gt_instances_3d
if pts_semantic_mask is None: ]
pts_semantic_mask = [None for i in range(len(gt_labels_3d))] for index in range(len(batch_gt_labels_3d)):
pts_instance_mask = [None for i in range(len(gt_labels_3d))] if len(batch_gt_labels_3d[index]) == 0:
fake_box = batch_gt_bboxes_3d[index].tensor.new_zeros(
1, batch_gt_bboxes_3d[index].tensor.shape[-1])
batch_gt_bboxes_3d[index] = batch_gt_bboxes_3d[index].new_box(
fake_box)
batch_gt_labels_3d[index] = batch_gt_labels_3d[
index].new_zeros(1)
if batch_pts_semantic_mask is None:
batch_pts_semantic_mask = [
None for _ in range(len(batch_gt_labels_3d))
]
batch_pts_instance_mask = [
None for _ in range(len(batch_gt_labels_3d))
]
(point_mask, point_sem, (point_mask, point_sem,
point_offset) = multi_apply(self.get_targets_single, points, point_offset) = multi_apply(self.get_targets_single, points,
gt_bboxes_3d, gt_labels_3d, batch_gt_bboxes_3d, batch_gt_labels_3d,
pts_semantic_mask, pts_instance_mask) batch_pts_semantic_mask,
batch_pts_instance_mask)
point_mask = torch.stack(point_mask) point_mask = torch.stack(point_mask)
point_sem = torch.stack(point_sem) point_sem = torch.stack(point_sem)
...@@ -759,7 +828,7 @@ class PrimitiveHead(BaseModule): ...@@ -759,7 +828,7 @@ class PrimitiveHead(BaseModule):
vote_xyz_reshape = primitive_center.view(batch_size * num_proposal, -1, vote_xyz_reshape = primitive_center.view(batch_size * num_proposal, -1,
3) 3)
center_loss = self.center_loss( center_loss = self.loss_center(
vote_xyz_reshape, vote_xyz_reshape,
gt_primitive_center, gt_primitive_center,
dst_weight=gt_primitive_mask.view(batch_size * num_proposal, 1))[1] dst_weight=gt_primitive_mask.view(batch_size * num_proposal, 1))[1]
...@@ -767,7 +836,7 @@ class PrimitiveHead(BaseModule): ...@@ -767,7 +836,7 @@ class PrimitiveHead(BaseModule):
if self.primitive_mode != 'line': if self.primitive_mode != 'line':
size_xyz_reshape = primitive_semantic.view( size_xyz_reshape = primitive_semantic.view(
batch_size * num_proposal, -1, self.num_dims).contiguous() batch_size * num_proposal, -1, self.num_dims).contiguous()
size_loss = self.semantic_reg_loss( size_loss = self.loss_semantic_reg(
size_xyz_reshape, size_xyz_reshape,
gt_primitive_semantic, gt_primitive_semantic,
dst_weight=gt_primitive_mask.view(batch_size * num_proposal, dst_weight=gt_primitive_mask.view(batch_size * num_proposal,
...@@ -776,7 +845,7 @@ class PrimitiveHead(BaseModule): ...@@ -776,7 +845,7 @@ class PrimitiveHead(BaseModule):
size_loss = center_loss.new_tensor(0.0) size_loss = center_loss.new_tensor(0.0)
# Semantic cls loss # Semantic cls loss
sem_cls_loss = self.semantic_cls_loss( sem_cls_loss = self.loss_semantic_cls(
semantic_scores, gt_sem_cls_label, weight=gt_primitive_mask) semantic_scores, gt_sem_cls_label, weight=gt_primitive_mask)
return center_loss, size_loss, sem_cls_loss return center_loss, size_loss, sem_cls_loss
......
import unittest
import torch
from mmengine import DefaultScope
from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
_get_detector_cfg, _setup_seed)
class TestH3D(unittest.TestCase):
    """Smoke test that builds H3DNet and runs its predict and loss modes."""

    def test_h3dnet(self):
        """Build H3DNet from the ScanNet config and check both forward modes.

        The detector is built from
        ``h3dnet/h3dnet_3x8_scannet-3d-18class.py`` and fed one synthetic
        sample with depth boxes plus semantic and instance point masks.
        """
        import mmdet3d.models

        assert hasattr(mmdet3d.models, 'H3DNet')
        DefaultScope.get_instance('test_H3DNet', scope_name='mmdet3d')
        _setup_seed(0)
        detector_cfg = _get_detector_cfg(
            'h3dnet/h3dnet_3x8_scannet-3d-18class.py')
        detector = MODELS.build(detector_cfg)
        gt_count = 5
        samples = [
            _create_detector_inputs(
                num_gt_instance=gt_count,
                points_feat_dim=4,
                bboxes_3d_type='depth',
                with_pts_semantic_mask=True,
                with_pts_instance_mask=True)
        ]

        # NOTE(review): the forward passes are assumed to run only when CUDA
        # is available (i.e. they sit under the CUDA guard) - confirm.
        if not torch.cuda.is_available():
            return
        detector = detector.cuda()

        # Prediction pass: one result per input sample is expected.
        with torch.no_grad():
            batch_inputs, data_samples = detector.data_preprocessor(
                samples, True)
            results = detector.forward(
                batch_inputs, data_samples, mode='predict')
        self.assertEqual(len(results), len(samples))
        for field in ('bboxes_3d', 'scores_3d', 'labels_3d'):
            self.assertIn(field, results[0].pred_instances_3d)

        # Loss pass, run without gradients to save memory.
        with torch.no_grad():
            losses = detector.forward(batch_inputs, data_samples, mode='loss')
        for loss_name in ('vote_loss', 'objectness_loss', 'center_loss'):
            self.assertGreater(losses[loss_name], 0)
...@@ -7,7 +7,8 @@ import numpy as np ...@@ -7,7 +7,8 @@ import numpy as np
import torch import torch
from mmengine import InstanceData from mmengine import InstanceData
from mmdet3d.core import Det3DDataSample, LiDARInstance3DBoxes, PointData from mmdet3d.core import (CameraInstance3DBoxes, DepthInstance3DBoxes,
Det3DDataSample, LiDARInstance3DBoxes, PointData)
def _setup_seed(seed): def _setup_seed(seed):
...@@ -71,19 +72,24 @@ def _get_detector_cfg(fname): ...@@ -71,19 +72,24 @@ def _get_detector_cfg(fname):
return model return model
def _create_detector_inputs( def _create_detector_inputs(seed=0,
seed=0, with_points=True,
with_points=True, with_img=False,
with_img=False, num_gt_instance=20,
num_gt_instance=20, num_points=10,
num_points=10, points_feat_dim=4,
points_feat_dim=4, num_classes=3,
num_classes=3, gt_bboxes_dim=7,
gt_bboxes_dim=7, with_pts_semantic_mask=False,
with_pts_semantic_mask=False, with_pts_instance_mask=False,
with_pts_instance_mask=False, bboxes_3d_type='lidar'):
):
_setup_seed(seed) _setup_seed(seed)
assert bboxes_3d_type in ('lidar', 'depth', 'cam')
bbox_3d_class = {
'lidar': LiDARInstance3DBoxes,
'depth': DepthInstance3DBoxes,
'cam': CameraInstance3DBoxes
}
if with_points: if with_points:
points = torch.rand([num_points, points_feat_dim]) points = torch.rand([num_points, points_feat_dim])
else: else:
...@@ -93,12 +99,13 @@ def _create_detector_inputs( ...@@ -93,12 +99,13 @@ def _create_detector_inputs(
else: else:
img = None img = None
inputs_dict = dict(img=img, points=points) inputs_dict = dict(img=img, points=points)
gt_instance_3d = InstanceData() gt_instance_3d = InstanceData()
gt_instance_3d.bboxes_3d = LiDARInstance3DBoxes( gt_instance_3d.bboxes_3d = bbox_3d_class[bboxes_3d_type](
torch.rand([num_gt_instance, gt_bboxes_dim]), box_dim=gt_bboxes_dim) torch.rand([num_gt_instance, gt_bboxes_dim]), box_dim=gt_bboxes_dim)
gt_instance_3d.labels_3d = torch.randint(0, num_classes, [num_gt_instance]) gt_instance_3d.labels_3d = torch.randint(0, num_classes, [num_gt_instance])
data_sample = Det3DDataSample( data_sample = Det3DDataSample(
metainfo=dict(box_type_3d=LiDARInstance3DBoxes)) metainfo=dict(box_type_3d=bbox_3d_class[bboxes_3d_type]))
data_sample.gt_instances_3d = gt_instance_3d data_sample.gt_instances_3d = gt_instance_3d
data_sample.gt_pts_seg = PointData() data_sample.gt_pts_seg = PointData()
if with_pts_instance_mask: if with_pts_instance_mask:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment