Commit 0e17beab authored by jshilong, committed by ChaimZhu

[Refactor] Refactor H3D

parent 9ebb75da
@@ -30,7 +30,7 @@ primitive_z_cfg = dict(
     conv_cfg=dict(type='Conv1d'),
     norm_cfg=dict(type='BN1d'),
     objectness_loss=dict(
-        type='CrossEntropyLoss',
+        type='mmdet.CrossEntropyLoss',
         class_weight=[0.4, 0.6],
         reduction='mean',
         loss_weight=30.0),
@@ -47,14 +47,16 @@ primitive_z_cfg = dict(
         loss_src_weight=0.5,
         loss_dst_weight=0.5),
     semantic_cls_loss=dict(
-        type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
+        type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
     train_cfg=dict(
+        sample_mode='vote',
         dist_thresh=0.2,
         var_thresh=1e-2,
         lower_thresh=1e-6,
         num_point=100,
         num_point_line=10,
-        line_thresh=0.2))
+        line_thresh=0.2),
+    test_cfg=dict(sample_mode='seed'))
 primitive_xy_cfg = dict(
     type='PrimitiveHead',
@@ -88,7 +90,7 @@ primitive_xy_cfg = dict(
     conv_cfg=dict(type='Conv1d'),
     norm_cfg=dict(type='BN1d'),
     objectness_loss=dict(
-        type='CrossEntropyLoss',
+        type='mmdet.CrossEntropyLoss',
         class_weight=[0.4, 0.6],
         reduction='mean',
         loss_weight=30.0),
@@ -105,14 +107,16 @@ primitive_xy_cfg = dict(
         loss_src_weight=0.5,
         loss_dst_weight=0.5),
     semantic_cls_loss=dict(
-        type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
+        type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
     train_cfg=dict(
+        sample_mode='vote',
         dist_thresh=0.2,
         var_thresh=1e-2,
         lower_thresh=1e-6,
         num_point=100,
         num_point_line=10,
-        line_thresh=0.2))
+        line_thresh=0.2),
+    test_cfg=dict(sample_mode='seed'))
 primitive_line_cfg = dict(
     type='PrimitiveHead',
@@ -146,7 +150,7 @@ primitive_line_cfg = dict(
     conv_cfg=dict(type='Conv1d'),
     norm_cfg=dict(type='BN1d'),
     objectness_loss=dict(
-        type='CrossEntropyLoss',
+        type='mmdet.CrossEntropyLoss',
         class_weight=[0.4, 0.6],
         reduction='mean',
         loss_weight=30.0),
@@ -163,17 +167,20 @@ primitive_line_cfg = dict(
         loss_src_weight=1.0,
         loss_dst_weight=1.0),
     semantic_cls_loss=dict(
-        type='CrossEntropyLoss', reduction='sum', loss_weight=2.0),
+        type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=2.0),
     train_cfg=dict(
+        sample_mode='vote',
         dist_thresh=0.2,
         var_thresh=1e-2,
         lower_thresh=1e-6,
         num_point=100,
         num_point_line=10,
-        line_thresh=0.2))
+        line_thresh=0.2),
+    test_cfg=dict(sample_mode='seed'))
 model = dict(
     type='H3DNet',
+    data_preprocessor=dict(type='Det3DDataPreprocessor'),
     backbone=dict(
         type='MultiBackbone',
         num_streams=4,
@@ -221,10 +228,8 @@ model = dict(
             normalize_xyz=True),
         pred_layer_cfg=dict(
             in_channels=128, shared_conv_channels=(128, 128), bias=True),
-        conv_cfg=dict(type='Conv1d'),
-        norm_cfg=dict(type='BN1d'),
         objectness_loss=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
             class_weight=[0.2, 0.8],
             reduction='sum',
             loss_weight=5.0),
@@ -235,15 +240,15 @@ model = dict(
             loss_src_weight=10.0,
             loss_dst_weight=10.0),
         dir_class_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
+            type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
         dir_res_loss=dict(
-            type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
         size_class_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
+            type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
         size_res_loss=dict(
-            type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
         semantic_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)),
+            type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0)),
     roi_head=dict(
         type='H3DRoIHead',
         primitive_list=[primitive_z_cfg, primitive_xy_cfg, primitive_line_cfg],
@@ -267,7 +272,6 @@ model = dict(
             mlp_channels=[128 + 12, 128, 64, 32],
             use_xyz=True,
             normalize_xyz=True),
-        feat_channels=(128, 128),
         primitive_refine_channels=[128, 128, 128],
         upper_thresh=100.0,
         surface_thresh=0.5,
@@ -275,7 +279,7 @@ model = dict(
         conv_cfg=dict(type='Conv1d'),
         norm_cfg=dict(type='BN1d'),
         objectness_loss=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
             class_weight=[0.2, 0.8],
             reduction='sum',
             loss_weight=5.0),
@@ -286,41 +290,47 @@ model = dict(
             loss_src_weight=10.0,
             loss_dst_weight=10.0),
         dir_class_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
+            type='mmdet.CrossEntropyLoss',
+            reduction='sum',
+            loss_weight=0.1),
         dir_res_loss=dict(
-            type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
         size_class_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
+            type='mmdet.CrossEntropyLoss',
+            reduction='sum',
+            loss_weight=0.1),
         size_res_loss=dict(
-            type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
         semantic_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
+            type='mmdet.CrossEntropyLoss',
+            reduction='sum',
+            loss_weight=0.1),
         cues_objectness_loss=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
             class_weight=[0.3, 0.7],
             reduction='mean',
             loss_weight=5.0),
         cues_semantic_loss=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
             class_weight=[0.3, 0.7],
             reduction='mean',
             loss_weight=5.0),
         proposal_objectness_loss=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
             class_weight=[0.2, 0.8],
             reduction='none',
             loss_weight=5.0),
         primitive_center_loss=dict(
-            type='MSELoss', reduction='none', loss_weight=1.0))),
+            type='mmdet.MSELoss', reduction='none', loss_weight=1.0))),
     # model training and testing settings
     train_cfg=dict(
         rpn=dict(
-            pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'),
+            pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mode='vote'),
         rpn_proposal=dict(use_nms=False),
         rcnn=dict(
             pos_distance_thr=0.3,
             neg_distance_thr=0.6,
-            sample_mod='vote',
+            sample_mode='vote',
             far_threshold=0.6,
             near_threshold=0.3,
             mask_surface_threshold=0.3,
@@ -329,13 +339,13 @@ model = dict(
             label_line_threshold=0.3)),
     test_cfg=dict(
         rpn=dict(
-            sample_mod='seed',
+            sample_mode='seed',
            nms_thr=0.25,
            score_thr=0.05,
            per_class_proposal=True,
            use_nms=False),
         rcnn=dict(
-            sample_mod='seed',
+            sample_mode='seed',
             nms_thr=0.25,
             score_thr=0.05,
             per_class_proposal=True)))
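The recurring `CrossEntropyLoss` -> `mmdet.CrossEntropyLoss` change above switches these configs to MMEngine's scoped registry lookup, so loss classes registered in mmdet can be built from mmdet3d configs. A minimal sketch of how such a scoped `type` resolves, assuming an MMEngine-style environment with mmdet installed:

from mmengine.registry import MODELS

# The 'mmdet.' prefix asks the registry to resolve the class from the
# 'mmdet' scope instead of the local (mmdet3d) one.
loss = MODELS.build(
    dict(type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0))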
_base_ = [
'../_base_/datasets/scannet-3d-18class.py', '../_base_/models/h3dnet.py',
'../_base_/schedules/schedule_3x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
rpn_head=dict(
num_classes=18,
bbox_coder=dict(
type='PartialBinBasedBBoxCoder',
num_sizes=18,
num_dir_bins=24,
with_rot=False,
mean_sizes=[[0.76966727, 0.8116021, 0.92573744],
[1.876858, 1.8425595, 1.1931566],
[0.61328, 0.6148609, 0.7182701],
[1.3955007, 1.5121545, 0.83443564],
[0.97949594, 1.0675149, 0.6329687],
[0.531663, 0.5955577, 1.7500148],
[0.9624706, 0.72462326, 1.1481868],
[0.83221924, 1.0490936, 1.6875663],
[0.21132214, 0.4206159, 0.5372846],
[1.4440073, 1.8970833, 0.26985747],
[1.0294262, 1.4040797, 0.87554324],
[1.3766412, 0.65521795, 1.6813129],
[0.6650819, 0.71111923, 1.298853],
[0.41999173, 0.37906948, 1.7513971],
[0.59359556, 0.5912492, 0.73919016],
[0.50867593, 0.50656086, 0.30136237],
[1.1511526, 1.0546296, 0.49706793],
[0.47535285, 0.49249494, 0.5802117]])),
roi_head=dict(
bbox_head=dict(
num_classes=18,
bbox_coder=dict(
type='PartialBinBasedBBoxCoder',
num_sizes=18,
num_dir_bins=24,
with_rot=False,
mean_sizes=[[0.76966727, 0.8116021, 0.92573744],
[1.876858, 1.8425595, 1.1931566],
[0.61328, 0.6148609, 0.7182701],
[1.3955007, 1.5121545, 0.83443564],
[0.97949594, 1.0675149, 0.6329687],
[0.531663, 0.5955577, 1.7500148],
[0.9624706, 0.72462326, 1.1481868],
[0.83221924, 1.0490936, 1.6875663],
[0.21132214, 0.4206159, 0.5372846],
[1.4440073, 1.8970833, 0.26985747],
[1.0294262, 1.4040797, 0.87554324],
[1.3766412, 0.65521795, 1.6813129],
[0.6650819, 0.71111923, 1.298853],
[0.41999173, 0.37906948, 1.7513971],
[0.59359556, 0.5912492, 0.73919016],
[0.50867593, 0.50656086, 0.30136237],
[1.1511526, 1.0546296, 0.49706793],
[0.47535285, 0.49249494, 0.5802117]]))))
train_dataloader = dict(
batch_size=3,
num_workers=2,
)
# yapf:disable
default_hooks = dict(
logger=dict(type='LoggerHook', interval=30)
)
# yapf:enable
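The `PartialBinBasedBBoxCoder` entries above encode heading as a bin class plus a residual (`num_dir_bins=24`) and size as a class over the 18 `mean_sizes` plus a residual. A hedged sketch of the bin-based heading arithmetic with hypothetical predicted values (note `with_rot=False` here, so ScanNet boxes ultimately keep yaw at zero):

import numpy as np

# Decoded heading = bin index * bin width + residual.
num_dir_bins = 24                      # matches the config above
angle_per_bin = 2 * np.pi / num_dir_bins
dir_class, dir_res = 5, 0.07           # assumed example predictions
yaw = dir_class * angle_per_bin + dir_res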
@@ -57,8 +57,13 @@ model = dict(
                     [1.1511526, 1.0546296, 0.49706793],
                     [0.47535285, 0.49249494, 0.5802117]]))))
-data = dict(samples_per_gpu=3, workers_per_gpu=2)
+train_dataloader = dict(
+    batch_size=3,
+    num_workers=2,
+)
 # yapf:disable
-log_config = dict(interval=30)
+default_hooks = dict(
+    logger=dict(type='LoggerHook', interval=30)
+)
 # yapf:enable
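The hunk above is the 0.x -> 1.x runtime-config migration; a sketch of the key mapping, with comments marking what each new key means:

# Hedged sketch of the key migration shown above.
legacy = dict(samples_per_gpu=3, workers_per_gpu=2)
train_dataloader = dict(
    batch_size=legacy['samples_per_gpu'],   # per-GPU batch size
    num_workers=legacy['workers_per_gpu'])  # dataloader workers per GPU
default_hooks = dict(
    logger=dict(type='LoggerHook', interval=30))  # replaces log_config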
@@ -5,11 +5,9 @@ from typing import Callable, List, Optional, Union
 import numpy as np

-from mmdet3d.core import show_result
 from mmdet3d.core.bbox import DepthInstance3DBoxes
 from mmdet3d.registry import DATASETS
 from .det3d_dataset import Det3DDataset
-from .pipelines import Compose
 from .seg3d_dataset import Seg3DDataset
@@ -151,46 +149,6 @@ class ScanNetDataset(Det3DDataset):
         return ann_info

-    def _build_default_pipeline(self):
-        """Build the default pipeline for this dataset."""
-        pipeline = [
-            dict(
-                type='LoadPointsFromFile',
-                coord_type='DEPTH',
-                shift_height=False,
-                load_dim=6,
-                use_dim=[0, 1, 2]),
-            dict(type='GlobalAlignment', rotation_axis=2),
-            dict(
-                type='DefaultFormatBundle3D',
-                class_names=self.CLASSES,
-                with_label=False),
-            dict(type='Collect3D', keys=['points'])
-        ]
-        return Compose(pipeline)
-
-    def show(self, results, out_dir, show=True, pipeline=None):
-        """Results visualization.
-
-        Args:
-            results (list[dict]): List of bounding boxes results.
-            out_dir (str): Output directory of visualization result.
-            show (bool): Visualize the results online.
-            pipeline (list[dict], optional): raw data loading for showing.
-                Default: None.
-        """
-        assert out_dir is not None, 'Expect out_dir, got none.'
-        pipeline = self._get_pipeline(pipeline)
-        for i, result in enumerate(results):
-            data_info = self.get_data_info[i]
-            pts_path = data_info['lidar_points']['lidar_path']
-            file_name = osp.split(pts_path)[-1].split('.')[0]
-            points = self._extract_data(i, pipeline, 'points').numpy()
-            gt_bboxes = self.get_ann_info(i)['gt_bboxes_3d'].tensor.numpy()
-            pred_bboxes = result['boxes_3d'].tensor.numpy()
-            show_result(points, gt_bboxes, pred_bboxes, out_dir, file_name,
-                        show)
-
 @DATASETS.register_module()
 class ScanNetSegDataset(Seg3DDataset):
 ...
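With `_build_default_pipeline` and `show` removed, visualization no longer lives on the dataset class; the dataset only parses annotations, and the `@DATASETS.register_module()` decorator keeps it buildable from configs. A hedged sketch of that registration pattern with a hypothetical toy class:

from mmdet3d.registry import DATASETS

@DATASETS.register_module()
class MyToyDataset:
    """Hypothetical dataset used only to illustrate registration."""

    def __init__(self, data_root: str):
        self.data_root = data_root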
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
 from mmcv.ops import furthest_point_sample
 from mmcv.runner import BaseModule
 from mmengine import ConfigDict, InstanceData
+from torch import Tensor
 from torch.nn import functional as F

 from mmdet3d.core.post_processing import aligned_3d_nms
@@ -161,7 +162,7 @@ class VoteHead(BaseModule):
                 points: List[torch.Tensor],
                 feats_dict: Dict[str, torch.Tensor],
                 batch_data_samples: List[Det3DDataSample],
-                rescale=True,
+                use_nms: bool = True,
                 **kwargs) -> List[InstanceData]:
         """
         Args:
@@ -169,8 +170,8 @@ class VoteHead(BaseModule):
             feats_dict (dict): Features from FPN or backbone.
             batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                 Samples. It usually includes meta information of data.
-            rescale (bool): Whether to rescale the results to
-                the original scale. Defaults to True.
+            use_nms (bool): Whether to apply NMS to the predictions.

         Returns:
             list[:obj:`InstanceData`]: List of processed predictions. Each
@@ -178,6 +179,9 @@ class VoteHead(BaseModule):
                 scores and labels.
         """
         preds_dict = self(feats_dict)
+        # `preds_dict` can be reused inside H3DNet
+        feats_dict.update(preds_dict)
         batch_size = len(batch_data_samples)
         batch_input_metas = []
         for batch_index in range(batch_size):
@@ -185,12 +189,73 @@ class VoteHead(BaseModule):
             batch_input_metas.append(metainfo)

         results_list = self.predict_by_feat(
-            points, preds_dict, batch_input_metas, rescale=rescale, **kwargs)
+            points, preds_dict, batch_input_metas, use_nms=use_nms, **kwargs)

         return results_list
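A hedged usage sketch of the new `predict` signature: with `use_nms=False` the head hands back raw decoded boxes, which is what H3DNet's refinement stage consumes (`vote_head`, `points`, `feats_dict`, and `batch_data_samples` are assumed to be built elsewhere):

# Raw proposals for a second stage; apply NMS only for final outputs.
proposals = vote_head.predict(
    points, feats_dict, batch_data_samples, use_nms=False)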
-    def loss(self, points: List[torch.Tensor], feats_dict: Dict[str,
-                                                                 torch.Tensor],
-             batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
+    def loss_and_predict(self,
+                         points: List[torch.Tensor],
+                         feats_dict: Dict[str, torch.Tensor],
+                         batch_data_samples: List[Det3DDataSample],
+                         ret_target: bool = False,
+                         proposal_cfg: dict = None,
+                         **kwargs) -> Tuple:
+        """
+        Args:
+            points (list[tensor]): Point clouds of multiple samples.
+            feats_dict (dict): Predictions from backbone or FPN.
+            batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
+                contains the meta information of each sample and
+                corresponding annotations.
+            ret_target (bool): Whether to return the assigned targets.
+                Defaults to False.
+            proposal_cfg (dict): Config for the proposal process.
+                Defaults to None.
+
+        Returns:
+            tuple: Contains losses and predictions after post-processing.
+        """
+        preds_dict = self.forward(feats_dict)
+        feats_dict.update(preds_dict)
+        batch_gt_instance_3d = []
+        batch_gt_instances_ignore = []
+        batch_input_metas = []
+        batch_pts_semantic_mask = []
+        batch_pts_instance_mask = []
+        for data_sample in batch_data_samples:
+            batch_input_metas.append(data_sample.metainfo)
+            batch_gt_instance_3d.append(data_sample.gt_instances_3d)
+            batch_gt_instances_ignore.append(
+                data_sample.get('ignored_instances', None))
+            batch_pts_semantic_mask.append(
+                data_sample.gt_pts_seg.get('pts_semantic_mask', None))
+            batch_pts_instance_mask.append(
+                data_sample.gt_pts_seg.get('pts_instance_mask', None))
+
+        loss_inputs = (points, preds_dict, batch_gt_instance_3d)
+        losses = self.loss_by_feat(
+            *loss_inputs,
+            batch_pts_semantic_mask=batch_pts_semantic_mask,
+            batch_pts_instance_mask=batch_pts_instance_mask,
+            batch_input_metas=batch_input_metas,
+            batch_gt_instances_ignore=batch_gt_instances_ignore,
+            ret_target=ret_target,
+            **kwargs)
+
+        results_list = self.predict_by_feat(
+            points,
+            preds_dict,
+            batch_input_metas,
+            use_nms=proposal_cfg.use_nms,
+            **kwargs)
+
+        return losses, results_list
+
+    def loss(self,
+             points: List[torch.Tensor],
+             feats_dict: Dict[str, torch.Tensor],
+             batch_data_samples: List[Det3DDataSample],
+             ret_target: bool = False,
+             **kwargs) -> dict:
""" """
Args: Args:
points (list[tensor]): Points cloud of multiple samples. points (list[tensor]): Points cloud of multiple samples.
...@@ -198,6 +263,8 @@ class VoteHead(BaseModule): ...@@ -198,6 +263,8 @@ class VoteHead(BaseModule):
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each sample and contains the meta information of each sample and
corresponding annotations. corresponding annotations.
ret_target (bool): Whether return the assigned target.
Defaults to False.
Returns: Returns:
dict: A dictionary of loss components. dict: A dictionary of loss components.
...@@ -224,7 +291,9 @@ class VoteHead(BaseModule): ...@@ -224,7 +291,9 @@ class VoteHead(BaseModule):
batch_pts_semantic_mask=batch_pts_semantic_mask, batch_pts_semantic_mask=batch_pts_semantic_mask,
batch_pts_instance_mask=batch_pts_instance_mask, batch_pts_instance_mask=batch_pts_instance_mask,
batch_input_metas=batch_input_metas, batch_input_metas=batch_input_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore) batch_gt_instances_ignore=batch_gt_instances_ignore,
ret_target=ret_target,
**kwargs)
return losses return losses
     def forward(self, feat_dict: dict) -> dict:
@@ -330,7 +399,7 @@ class VoteHead(BaseModule):
             batch_pts_semantic_mask (list[tensor]): Semantic mask
                 of point clouds. Defaults to None.
             batch_input_metas (list[dict]): Contain pcd and img's meta info.
-            ret_target (bool): Return targets or not.
+            ret_target (bool): Return targets or not. Defaults to False.

         Returns:
             dict: Losses of VoteNet.
@@ -671,9 +740,10 @@ class VoteHead(BaseModule):
                 while using vote head in rpn stage.

         Returns:
-            list[:obj:`InstanceData`]: List of processed predictions. Each
-                InstanceData contains 3d bounding boxes and corresponding
-                scores and labels.
+            list[:obj:`InstanceData`] or Tensor: Return a list of processed
+                predictions when `use_nms` is True. Each InstanceData
+                contains 3d bounding boxes and corresponding scores and
+                labels. Return the raw bboxes when `use_nms` is False.
         """
         # decode boxes
         stack_points = torch.stack(points)
@@ -683,9 +753,9 @@ class VoteHead(BaseModule):
         batch_size = bbox3d.shape[0]
         results_list = list()

-        for b in range(batch_size):
-            temp_results = InstanceData()
-            if use_nms:
+        if use_nms:
+            for b in range(batch_size):
+                temp_results = InstanceData()
                 bbox_selected, score_selected, labels = \
                     self.multiclass_nms_single(obj_scores[b],
                                                sem_scores[b],
@@ -700,20 +770,15 @@ class VoteHead(BaseModule):
                 temp_results.scores_3d = score_selected
                 temp_results.labels_3d = labels
                 results_list.append(temp_results)
-            else:
-                bbox = batch_input_metas[b]['box_type_3d'](
-                    bbox_selected,
-                    box_dim=bbox_selected.shape[-1],
-                    with_yaw=self.bbox_coder.with_rot)
-                temp_results.bboxes_3d = bbox
-                temp_results.obj_scores_3d = obj_scores[b]
-                temp_results.sem_scores_3d = obj_scores[b]
-                results_list.append(temp_results)

-        return results_list
+            return results_list
+        else:
+            # TODO: unify this when refactoring aug_test
+            return bbox3d

-    def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
-                              input_meta):
+    def multiclass_nms_single(self, obj_scores: Tensor, sem_scores: Tensor,
+                              bbox: Tensor, points: Tensor,
+                              input_meta: dict) -> Tuple:
         """Multi-class nms in single batch.

         Args:
 ...
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Union

 import torch
+from torch import Tensor

-from mmdet3d.core import merge_aug_bboxes_3d
 from mmdet3d.registry import MODELS
+from ...core import Det3DDataSample
 from .two_stage import TwoStage3DDetector
@@ -11,17 +14,33 @@ class H3DNet(TwoStage3DDetector):
     r"""H3DNet model.

     Please refer to the `paper <https://arxiv.org/abs/2006.05682>`_

+    Args:
+        backbone (dict): Config dict of detector's backbone.
+        neck (dict, optional): Config dict of neck. Defaults to None.
+        rpn_head (dict, optional): Config dict of rpn head. Defaults to None.
+        roi_head (dict, optional): Config dict of roi head. Defaults to None.
+        train_cfg (dict, optional): Config dict of training hyper-parameters.
+            Defaults to None.
+        test_cfg (dict, optional): Config dict of test hyper-parameters.
+            Defaults to None.
+        init_cfg (dict, optional): The config to control the initialization.
+            Defaults to None.
+        data_preprocessor (dict or ConfigDict, optional): The pre-process
+            config of :class:`BaseDataPreprocessor`. It usually includes
+            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
     """

     def __init__(self,
-                 backbone,
-                 neck=None,
-                 rpn_head=None,
-                 roi_head=None,
-                 train_cfg=None,
-                 test_cfg=None,
-                 pretrained=None,
-                 init_cfg=None):
+                 backbone: dict,
+                 neck: Optional[dict] = None,
+                 rpn_head: Optional[dict] = None,
+                 roi_head: Optional[dict] = None,
+                 train_cfg: Optional[dict] = None,
+                 test_cfg: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None,
+                 data_preprocessor: Optional[dict] = None,
+                 **kwargs) -> None:
         super(H3DNet, self).__init__(
             backbone=backbone,
             neck=neck,
@@ -29,148 +48,110 @@ class H3DNet(TwoStage3DDetector):
             roi_head=roi_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
-            pretrained=pretrained,
-            init_cfg=init_cfg)
+            init_cfg=init_cfg,
+            data_preprocessor=data_preprocessor,
+            **kwargs)
-    def forward_train(self,
-                      points,
-                      img_metas,
-                      gt_bboxes_3d,
-                      gt_labels_3d,
-                      pts_semantic_mask=None,
-                      pts_instance_mask=None,
-                      gt_bboxes_ignore=None):
-        """Forward of training.
+    def extract_feat(self, batch_inputs_dict: dict) -> dict:
+        """Directly extract features from the backbone+neck.

         Args:
-            points (list[torch.Tensor]): Points of each batch.
-            img_metas (list): Image metas.
-            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
-            gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
-            pts_semantic_mask (list[torch.Tensor]): point-wise semantic
-                label of each batch.
-            pts_instance_mask (list[torch.Tensor]): point-wise instance
-                label of each batch.
-            gt_bboxes_ignore (list[torch.Tensor]): Specify
-                which bounding.
+            batch_inputs_dict (dict): The model input dict which includes
+                'points'.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.

         Returns:
-            dict: Losses.
+            dict: Dict of feature.
         """
-        points_cat = torch.stack(points)
+        stack_points = torch.stack(batch_inputs_dict['points'])
+        x = self.backbone(stack_points)
+        if self.with_neck:
+            x = self.neck(x)
+        return x
+
+    def loss(self, batch_inputs_dict: Dict[str, Union[List, Tensor]],
+             batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
+        """
+        Args:
+            batch_inputs_dict (dict): The model input dict which includes
+                'points' keys.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.
+
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instance_3d`.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
-        feats_dict = self.extract_feat(points_cat)
+        feats_dict = self.extract_feat(batch_inputs_dict)
         feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
         feats_dict['fp_features'] = [feats_dict['hd_feature']]
         feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]

         losses = dict()
         if self.with_rpn:
-            rpn_outs = self.rpn_head(feats_dict, self.train_cfg.rpn.sample_mod)
-            feats_dict.update(rpn_outs)
-
-            rpn_loss_inputs = (points, gt_bboxes_3d, gt_labels_3d,
-                               pts_semantic_mask, pts_instance_mask, img_metas)
-            rpn_losses = self.rpn_head.loss(
-                rpn_outs,
-                *rpn_loss_inputs,
-                gt_bboxes_ignore=gt_bboxes_ignore,
-                ret_target=True)
-            feats_dict['targets'] = rpn_losses.pop('targets')
-            losses.update(rpn_losses)
-
-            # Generate rpn proposals
             proposal_cfg = self.train_cfg.get('rpn_proposal',
                                               self.test_cfg.rpn)
-            proposal_inputs = (points, rpn_outs, img_metas)
-            proposal_list = self.rpn_head.get_bboxes(
-                *proposal_inputs, use_nms=proposal_cfg.use_nms)
-            feats_dict['proposal_list'] = proposal_list
+            # note: new keys and values are added to `feats_dict` in rpn_head
+            rpn_losses, rpn_proposals = self.rpn_head.loss_and_predict(
+                batch_inputs_dict['points'],
+                feats_dict,
+                batch_data_samples,
+                ret_target=True,
+                proposal_cfg=proposal_cfg)
+            feats_dict['targets'] = rpn_losses.pop('targets')
+            losses.update(rpn_losses)
+            feats_dict['rpn_proposals'] = rpn_proposals
         else:
             raise NotImplementedError

-        roi_losses = self.roi_head.forward_train(feats_dict, img_metas, points,
-                                                 gt_bboxes_3d, gt_labels_3d,
-                                                 pts_semantic_mask,
-                                                 pts_instance_mask,
-                                                 gt_bboxes_ignore)
+        roi_losses = self.roi_head.loss(batch_inputs_dict['points'],
+                                        feats_dict, batch_data_samples,
+                                        **kwargs)
         losses.update(roi_losses)

         return losses
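A hedged sketch of the new single-pass RPN contract used in `loss` above: `loss_and_predict` returns losses and proposals together, so the RPN forward runs once instead of the old `loss` + `get_bboxes` double pass:

# One forward pass serves both training losses and proposal generation.
rpn_losses, rpn_proposals = self.rpn_head.loss_and_predict(
    batch_inputs_dict['points'], feats_dict, batch_data_samples,
    ret_target=True, proposal_cfg=proposal_cfg)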
-    def simple_test(self, points, img_metas, imgs=None, rescale=False):
-        """Forward of testing.
+    def predict(
+        self, batch_input_dict: Dict,
+        batch_data_samples: List[Det3DDataSample]
+    ) -> List[Det3DDataSample]:
+        """Get model predictions.

         Args:
             points (list[torch.Tensor]): Points of each sample.
-            img_metas (list): Image metas.
-            rescale (bool): Whether to rescale results.
+            batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
+                contains the meta information of each sample and
+                corresponding annotations.

         Returns:
             list: Predicted 3d boxes.
         """
-        points_cat = torch.stack(points)
-
-        feats_dict = self.extract_feat(points_cat)
+        feats_dict = self.extract_feat(batch_input_dict)
         feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
         feats_dict['fp_features'] = [feats_dict['hd_feature']]
         feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]

         if self.with_rpn:
             proposal_cfg = self.test_cfg.rpn
-            rpn_outs = self.rpn_head(feats_dict, proposal_cfg.sample_mod)
-            feats_dict.update(rpn_outs)
-            # Generate rpn proposals
-            proposal_list = self.rpn_head.get_bboxes(
-                points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms)
-            feats_dict['proposal_list'] = proposal_list
+            rpn_proposals = self.rpn_head.predict(
+                batch_input_dict['points'],
+                feats_dict,
+                batch_data_samples,
+                use_nms=proposal_cfg.use_nms)
+            feats_dict['rpn_proposals'] = rpn_proposals
         else:
             raise NotImplementedError

-        return self.roi_head.simple_test(
-            feats_dict, img_metas, points_cat, rescale=rescale)
-
-    def aug_test(self, points, img_metas, imgs=None, rescale=False):
-        """Test with augmentation."""
-        points_cat = [torch.stack(pts) for pts in points]
-        feats_dict = self.extract_feats(points_cat, img_metas)
-        for feat_dict in feats_dict:
-            feat_dict['fp_xyz'] = [feat_dict['fp_xyz_net0'][-1]]
-            feat_dict['fp_features'] = [feat_dict['hd_feature']]
-            feat_dict['fp_indices'] = [feat_dict['fp_indices_net0'][-1]]
-
-        # only support aug_test for one sample
-        aug_bboxes = []
-        for feat_dict, pts_cat, img_meta in zip(feats_dict, points_cat,
-                                                img_metas):
-            if self.with_rpn:
-                proposal_cfg = self.test_cfg.rpn
-                rpn_outs = self.rpn_head(feat_dict, proposal_cfg.sample_mod)
-                feat_dict.update(rpn_outs)
-                # Generate rpn proposals
-                proposal_list = self.rpn_head.get_bboxes(
-                    points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms)
-                feat_dict['proposal_list'] = proposal_list
-            else:
-                raise NotImplementedError
-
-            bbox_results = self.roi_head.simple_test(
-                feat_dict,
-                self.test_cfg.rcnn.sample_mod,
-                img_meta,
-                pts_cat,
-                rescale=rescale)
-            aug_bboxes.append(bbox_results)
-
-        # after merging, bboxes will be rescaled to the original image size
-        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
-                                            self.bbox_head.test_cfg)
-        return [merged_bboxes]
-
-    def extract_feats(self, points, img_metas):
-        """Extract features of multiple samples."""
-        return [
-            self.extract_feat(pts, img_meta)
-            for pts, img_meta in zip(points, img_metas)
-        ]
+        results_list = self.roi_head.predict(
+            batch_input_dict['points'],
+            feats_dict,
+            batch_data_samples,
+            suffix='_optimized')
+        return self.convert_to_datasample(results_list)
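A hypothetical inference sketch against the new entry point, assuming `model` is a built H3DNet and the data preprocessor has already produced `batch_inputs_dict` and `batch_data_samples`:

import torch

with torch.no_grad():
    # Returns Det3DDataSample objects with predictions attached.
    results = model.predict(batch_inputs_dict, batch_data_samples)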
@@ -56,12 +56,12 @@ class PointRCNN(TwoStage3DDetector):
             x = self.neck(x)
         return x

-    def forward_train(self, points, img_metas, gt_bboxes_3d, gt_labels_3d):
+    def forward_train(self, points, input_metas, gt_bboxes_3d, gt_labels_3d):
         """Forward of training.

         Args:
             points (list[torch.Tensor]): Points of each batch.
-            img_metas (list[dict]): Meta information of each sample.
+            input_metas (list[dict]): Meta information of each sample.
             gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
             gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
@@ -69,8 +69,8 @@ class PointRCNN(TwoStage3DDetector):
             dict: Losses.
         """
         losses = dict()
-        points_cat = torch.stack(points)
-        x = self.extract_feat(points_cat)
+        stack_points = torch.stack(points)
+        x = self.extract_feat(stack_points)

         # features for rcnn
         backbone_feats = x['fp_features'].clone()
@@ -85,11 +85,11 @@ class PointRCNN(TwoStage3DDetector):
             points=points,
             gt_bboxes_3d=gt_bboxes_3d,
             gt_labels_3d=gt_labels_3d,
-            img_metas=img_metas)
+            input_metas=input_metas)
         losses.update(rpn_loss)

-        bbox_list = self.rpn_head.get_bboxes(points_cat, bbox_preds, cls_preds,
-                                             img_metas)
+        bbox_list = self.rpn_head.get_bboxes(stack_points, bbox_preds,
+                                             cls_preds, input_metas)
         proposal_list = [
             dict(
                 boxes_3d=bboxes,
@@ -100,7 +100,7 @@ class PointRCNN(TwoStage3DDetector):
         ]
         rcnn_feats.update({'points_cls_preds': cls_preds})

-        roi_losses = self.roi_head.forward_train(rcnn_feats, img_metas,
+        roi_losses = self.roi_head.forward_train(rcnn_feats, input_metas,
                                                  proposal_list, gt_bboxes_3d,
                                                  gt_labels_3d)
         losses.update(roi_losses)
@@ -121,9 +121,9 @@ class PointRCNN(TwoStage3DDetector):
         Returns:
             list: Predicted 3d boxes.
         """
-        points_cat = torch.stack(points)
-        x = self.extract_feat(points_cat)
+        stack_points = torch.stack(points)
+        x = self.extract_feat(stack_points)

         # features for rcnn
         backbone_feats = x['fp_features'].clone()
         backbone_xyz = x['fp_xyz'].clone()
@@ -132,7 +132,7 @@ class PointRCNN(TwoStage3DDetector):
         rcnn_feats.update({'points_cls_preds': cls_preds})

         bbox_list = self.rpn_head.get_bboxes(
-            points_cat, bbox_preds, cls_preds, img_metas, rescale=rescale)
+            stack_points, bbox_preds, cls_preds, img_metas, rescale=rescale)
         proposal_list = [
             dict(
 ...
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple

 import torch
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule
+from mmengine import InstanceData
+from torch import Tensor
 from torch import nn as nn
 from torch.nn import functional as F

-from mmdet3d.core import build_bbox_coder
+from mmdet3d.core import BaseInstance3DBoxes, Det3DDataSample
 from mmdet3d.core.bbox import DepthInstance3DBoxes
 from mmdet3d.core.post_processing import aligned_3d_nms
-from mmdet3d.models.builder import build_loss
 from mmdet3d.models.losses import chamfer_distance
 from mmdet3d.ops import build_sa_module
-from mmdet3d.registry import MODELS
+from mmdet3d.registry import MODELS, TASK_UTILS
 from mmdet.core import multi_apply
@@ -25,66 +28,73 @@ class H3DBboxHead(BaseModule):
         line_matching_cfg (dict): Config for line primitive matching.
         bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
             decoding boxes.
-        train_cfg (dict): Config for training.
-        test_cfg (dict): Config for testing.
+        train_cfg (dict): Config for training. Defaults to None.
+        test_cfg (dict): Config for testing. Defaults to None.
         gt_per_seed (int): Number of ground truth votes generated
-            from each seed point.
+            from each seed point. Defaults to 1.
         num_proposal (int): Number of proposal votes generated.
-        feat_channels (tuple[int]): Convolution channels of
-            prediction layer.
+            Defaults to 256.
         primitive_feat_refine_streams (int): The number of mlps to
-            refine primitive feature.
+            refine primitive feature. Defaults to 2.
         primitive_refine_channels (tuple[int]): Convolution channels of
-            prediction layer.
-        upper_thresh (float): Threshold for line matching.
+            prediction layer. Defaults to [128, 128, 128].
+        upper_thresh (float): Threshold for line matching. Defaults to 100.
         surface_thresh (float): Threshold for surface matching.
-        line_thresh (float): Threshold for line matching.
+            Defaults to 0.5.
+        line_thresh (float): Threshold for line matching. Defaults to 0.5.
         conv_cfg (dict): Config of convolution in prediction layer.
-        norm_cfg (dict): Config of BN in prediction layer.
-        objectness_loss (dict): Config of objectness loss.
-        center_loss (dict): Config of center loss.
+            Defaults to None.
+        norm_cfg (dict): Config of BN in prediction layer. Defaults to None.
+        objectness_loss (dict): Config of objectness loss. Defaults to None.
+        center_loss (dict): Config of center loss. Defaults to None.
         dir_class_loss (dict): Config of direction classification loss.
+            Defaults to None.
         dir_res_loss (dict): Config of direction residual regression loss.
+            Defaults to None.
         size_class_loss (dict): Config of size classification loss.
+            Defaults to None.
         size_res_loss (dict): Config of size residual regression loss.
+            Defaults to None.
         semantic_loss (dict): Config of point-wise semantic segmentation loss.
+            Defaults to None.
         cues_objectness_loss (dict): Config of cues objectness loss.
+            Defaults to None.
         cues_semantic_loss (dict): Config of cues semantic loss.
+            Defaults to None.
         proposal_objectness_loss (dict): Config of proposal objectness
-            loss.
+            loss. Defaults to None.
         primitive_center_loss (dict): Config of primitive center regression
-            loss.
+            loss. Defaults to None.
     """
     def __init__(self,
-                 num_classes,
-                 suface_matching_cfg,
-                 line_matching_cfg,
-                 bbox_coder,
-                 train_cfg=None,
-                 test_cfg=None,
-                 gt_per_seed=1,
-                 num_proposal=256,
-                 feat_channels=(128, 128),
-                 primitive_feat_refine_streams=2,
-                 primitive_refine_channels=[128, 128, 128],
-                 upper_thresh=100.0,
-                 surface_thresh=0.5,
-                 line_thresh=0.5,
-                 conv_cfg=dict(type='Conv1d'),
-                 norm_cfg=dict(type='BN1d'),
-                 objectness_loss=None,
-                 center_loss=None,
-                 dir_class_loss=None,
-                 dir_res_loss=None,
-                 size_class_loss=None,
-                 size_res_loss=None,
-                 semantic_loss=None,
-                 cues_objectness_loss=None,
-                 cues_semantic_loss=None,
-                 proposal_objectness_loss=None,
-                 primitive_center_loss=None,
-                 init_cfg=None):
+                 num_classes: int,
+                 suface_matching_cfg: dict,
+                 line_matching_cfg: dict,
+                 bbox_coder: dict,
+                 train_cfg: Optional[dict] = None,
+                 test_cfg: Optional[dict] = None,
+                 gt_per_seed: int = 1,
+                 num_proposal: int = 256,
+                 primitive_feat_refine_streams: int = 2,
+                 primitive_refine_channels: List[int] = [128, 128, 128],
+                 upper_thresh: float = 100.0,
+                 surface_thresh: float = 0.5,
+                 line_thresh: float = 0.5,
+                 conv_cfg: dict = dict(type='Conv1d'),
+                 norm_cfg: dict = dict(type='BN1d'),
+                 objectness_loss: Optional[dict] = None,
+                 center_loss: Optional[dict] = None,
+                 dir_class_loss: Optional[dict] = None,
+                 dir_res_loss: Optional[dict] = None,
+                 size_class_loss: Optional[dict] = None,
+                 size_res_loss: Optional[dict] = None,
+                 semantic_loss: Optional[dict] = None,
+                 cues_objectness_loss: Optional[dict] = None,
+                 cues_semantic_loss: Optional[dict] = None,
+                 proposal_objectness_loss: Optional[dict] = None,
+                 primitive_center_loss: Optional[dict] = None,
+                 init_cfg: dict = None):
         super(H3DBboxHead, self).__init__(init_cfg=init_cfg)
         self.num_classes = num_classes
         self.train_cfg = train_cfg
@@ -96,22 +106,22 @@ class H3DBboxHead(BaseModule):
         self.surface_thresh = surface_thresh
         self.line_thresh = line_thresh

-        self.objectness_loss = build_loss(objectness_loss)
-        self.center_loss = build_loss(center_loss)
-        self.dir_class_loss = build_loss(dir_class_loss)
-        self.dir_res_loss = build_loss(dir_res_loss)
-        self.size_class_loss = build_loss(size_class_loss)
-        self.size_res_loss = build_loss(size_res_loss)
-        self.semantic_loss = build_loss(semantic_loss)
+        self.loss_objectness = MODELS.build(objectness_loss)
+        self.loss_center = MODELS.build(center_loss)
+        self.loss_dir_class = MODELS.build(dir_class_loss)
+        self.loss_dir_res = MODELS.build(dir_res_loss)
+        self.loss_size_class = MODELS.build(size_class_loss)
+        self.loss_size_res = MODELS.build(size_res_loss)
+        self.loss_semantic = MODELS.build(semantic_loss)

-        self.bbox_coder = build_bbox_coder(bbox_coder)
+        self.bbox_coder = TASK_UTILS.build(bbox_coder)
         self.num_sizes = self.bbox_coder.num_sizes
         self.num_dir_bins = self.bbox_coder.num_dir_bins

-        self.cues_objectness_loss = build_loss(cues_objectness_loss)
-        self.cues_semantic_loss = build_loss(cues_semantic_loss)
-        self.proposal_objectness_loss = build_loss(proposal_objectness_loss)
-        self.primitive_center_loss = build_loss(primitive_center_loss)
+        self.loss_cues_objectness = MODELS.build(cues_objectness_loss)
+        self.loss_cues_semantic = MODELS.build(cues_semantic_loss)
+        self.loss_proposal_objectness = MODELS.build(proposal_objectness_loss)
+        self.loss_primitive_center = MODELS.build(primitive_center_loss)

         assert suface_matching_cfg['mlp_channels'][-1] == \
             line_matching_cfg['mlp_channels'][-1]
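The block above swaps the removed `build_loss`/`build_bbox_coder` helpers for direct registry calls; a hedged one-line equivalent, assuming the loss class is registered in mmdet3d's MODELS registry:

from mmdet3d.registry import MODELS

# build_loss(cfg) and MODELS.build(cfg) both return an instantiated module.
loss_center = MODELS.build(
    dict(type='ChamferDistance', mode='l2', reduction='sum'))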
@@ -202,16 +212,14 @@ class H3DBboxHead(BaseModule):
                           bbox_coder['num_sizes'] * 4 + self.num_classes)
         self.bbox_pred.append(nn.Conv1d(prev_channel, conv_out_channel, 1))

-    def forward(self, feats_dict, sample_mod):
+    def forward(self, feats_dict: dict):
         """Forward pass.

         Args:
             feats_dict (dict): Feature dict from backbone.
-            sample_mod (str): Sample mode for vote aggregation layer.
-                valid modes are "vote", "seed" and "random".

         Returns:
-            dict: Predictions of vote head.
+            dict: Predictions of head.
         """
         ret_dict = {}
         aggregated_points = feats_dict['aggregated_points']
@@ -236,7 +244,7 @@ class H3DBboxHead(BaseModule):
             dim=1)

         # Extract the surface and line centers of rpn proposals
-        rpn_proposals = feats_dict['proposal_list']
+        rpn_proposals = feats_dict['rpn_proposals']
         rpn_proposals_bbox = DepthInstance3DBoxes(
             rpn_proposals.reshape(-1, 7).clone(),
             box_dim=rpn_proposals.shape[-1],
@@ -310,36 +318,29 @@ class H3DBboxHead(BaseModule):
             ret_dict[key + '_optimized'] = refine_decode_res[key]
         return ret_dict
-    def loss(self,
-             bbox_preds,
-             points,
-             gt_bboxes_3d,
-             gt_labels_3d,
-             pts_semantic_mask=None,
-             pts_instance_mask=None,
-             img_metas=None,
-             rpn_targets=None,
-             gt_bboxes_ignore=None):
-        """Compute loss.
+    def loss(
+        self,
+        points: List[Tensor],
+        feats_dict: dict,
+        rpn_targets: Tuple = None,
+        batch_data_samples: List[Det3DDataSample] = None,
+    ):
+        """
         Args:
-            bbox_preds (dict): Predictions from forward of h3d bbox head.
-            points (list[torch.Tensor]): Input points.
-            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
-                bboxes of each sample.
-            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
-            pts_semantic_mask (list[torch.Tensor]): Point-wise
-                semantic mask.
-            pts_instance_mask (list[torch.Tensor]): Point-wise
-                instance mask.
-            img_metas (list[dict]): Contain pcd and img's meta info.
-            rpn_targets (Tuple): Targets generated by rpn head.
-            gt_bboxes_ignore (list[torch.Tensor]): Specify
-                which bounding.
+            points (list[tensor]): Point clouds of multiple samples.
+            feats_dict (dict): Predictions from backbone or FPN.
+            rpn_targets (tuple, optional): The targets sampled by the RPN.
+                Defaults to None.
+            batch_data_samples (list[:obj:`Det3DDataSample`], optional):
+                Each item contains the meta information of each sample
+                and corresponding annotations. Defaults to None.

         Returns:
-            dict: Losses of H3dnet.
+            dict: A dictionary of loss components.
         """
+        preds = self(feats_dict)
+        feats_dict.update(preds)
         (vote_targets, vote_target_masks, size_class_targets, size_res_targets,
          dir_class_targets, dir_res_targets, center_targets, _, mask_targets,
          valid_gt_masks, objectness_targets, objectness_weights,
@@ -349,7 +350,7 @@ class H3DBboxHead(BaseModule):

         # calculate refined proposal loss
         refined_proposal_loss = self.get_proposal_stage_loss(
-            bbox_preds,
+            feats_dict,
             size_class_targets,
             size_res_targets,
             dir_class_targets,
@@ -364,36 +365,60 @@ class H3DBboxHead(BaseModule):
         for key in refined_proposal_loss.keys():
             losses[key + '_optimized'] = refined_proposal_loss[key]

+        batch_gt_instance_3d = []
+        batch_input_metas = []
+        for data_sample in batch_data_samples:
+            batch_input_metas.append(data_sample.metainfo)
+            batch_gt_instance_3d.append(data_sample.gt_instances_3d)
+
+        temp_loss = self.loss_by_feat(points, feats_dict, batch_gt_instance_3d)
+        losses.update(temp_loss)
+        return losses
+    def loss_by_feat(self, points: List[torch.Tensor], feats_dict: dict,
+                     batch_gt_instances_3d: List[InstanceData],
+                     **kwargs) -> dict:
+        """Compute loss.
+
+        Args:
+            points (list[torch.Tensor]): Input points.
+            feats_dict (dict): Predictions from forward of vote head.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances. It usually includes ``bboxes`` and ``labels``
+                attributes.
+
+        Returns:
+            dict: Losses of H3DNet.
+        """
         bbox3d_optimized = self.bbox_coder.decode(
-            bbox_preds, suffix='_optimized')
+            feats_dict, suffix='_optimized')

-        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
-                                   pts_semantic_mask, pts_instance_mask,
-                                   bbox_preds)
+        targets = self.get_targets(points, feats_dict, batch_gt_instances_3d)

         (cues_objectness_label, cues_sem_label, proposal_objectness_label,
          cues_mask, cues_match_mask, proposal_objectness_mask,
          cues_matching_label, obj_surface_line_center) = targets

         # match scores for each geometric primitive
-        objectness_scores = bbox_preds['matching_score']
+        objectness_scores = feats_dict['matching_score']
         # match scores for the semantics of primitives
-        objectness_scores_sem = bbox_preds['semantic_matching_score']
+        objectness_scores_sem = feats_dict['semantic_matching_score']

-        primitive_objectness_loss = self.cues_objectness_loss(
+        primitive_objectness_loss = self.loss_cues_objectness(
             objectness_scores.transpose(2, 1),
             cues_objectness_label,
             weight=cues_mask,
             avg_factor=cues_mask.sum() + 1e-6)

-        primitive_sem_loss = self.cues_semantic_loss(
+        primitive_sem_loss = self.loss_cues_semantic(
             objectness_scores_sem.transpose(2, 1),
             cues_sem_label,
             weight=cues_mask,
             avg_factor=cues_mask.sum() + 1e-6)

-        objectness_scores = bbox_preds['obj_scores_optimized']
-        objectness_loss_refine = self.proposal_objectness_loss(
+        objectness_scores = feats_dict['obj_scores_optimized']
+        objectness_loss_refine = self.loss_proposal_objectness(
             objectness_scores.transpose(2, 1), proposal_objectness_label)
         primitive_matching_loss = (objectness_loss_refine *
                                    cues_match_mask).sum() / (
@@ -419,7 +444,7 @@ class H3DBboxHead(BaseModule):
         pred_surface_line_center = torch.cat(
             (pred_obj_surface_center, pred_obj_line_center), 1)

-        square_dist = self.primitive_center_loss(pred_surface_line_center,
+        square_dist = self.loss_primitive_center(pred_surface_line_center,
                                                  obj_surface_line_center)
         match_dist = torch.sqrt(square_dist.sum(dim=-1) + 1e-6)
@@ -434,58 +459,102 @@ class H3DBboxHead(BaseModule):
             primitive_sem_matching_loss=primitive_sem_matching_loss,
             primitive_centroid_reg_loss=primitive_centroid_reg_loss)

-        losses.update(refined_loss)
+        return refined_loss
-        return losses
-
-    def get_bboxes(self,
-                   points,
-                   bbox_preds,
-                   input_metas,
-                   rescale=False,
-                   suffix=''):
+    def predict(self,
+                points: List[torch.Tensor],
+                feats_dict: Dict[str, torch.Tensor],
+                batch_data_samples: List[Det3DDataSample],
+                suffix='_optimized',
+                **kwargs) -> List[InstanceData]:
+        """
+        Args:
+            points (list[tensor]): Point clouds of multiple samples.
+            feats_dict (dict): Features from FPN or backbone.
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                Samples. It usually includes meta information of data.
+            suffix (str): Suffix for tensors in feats_dict.
+                Defaults to '_optimized'.
+
+        Returns:
+            list[:obj:`InstanceData`]: List of processed predictions. Each
+                InstanceData contains 3d bounding boxes and corresponding
+                scores and labels.
+        """
+        preds_dict = self(feats_dict)
+        # `preds_dict` can be reused inside H3DNet
+        feats_dict.update(preds_dict)
+        batch_size = len(batch_data_samples)
+        batch_input_metas = []
+        for batch_index in range(batch_size):
+            metainfo = batch_data_samples[batch_index].metainfo
+            batch_input_metas.append(metainfo)
+
+        results_list = self.predict_by_feat(
+            points, feats_dict, batch_input_metas, suffix=suffix, **kwargs)
+        return results_list
+    def predict_by_feat(self,
+                        points: List[torch.Tensor],
+                        feats_dict: dict,
+                        batch_input_metas: List[dict],
+                        suffix='_optimized',
+                        **kwargs) -> List[InstanceData]:
         """Generate bboxes from vote head predictions.

         Args:
-            points (torch.Tensor): Input points.
-            bbox_preds (dict): Predictions from vote head.
-            input_metas (list[dict]): Point cloud and image's meta info.
-            rescale (bool): Whether to rescale bboxes.
+            points (List[torch.Tensor]): Input points of multiple samples.
+            feats_dict (dict): Predictions from previous components.
+            batch_input_metas (list[dict]): Each item contains the meta
+                information of each sample.
+            suffix (str): Suffix for tensors in feats_dict.
+                Defaults to '_optimized'.

         Returns:
-            list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
+            list[:obj:`InstanceData`]: Return a list of processed
+                predictions. Each InstanceData contains 3d bounding boxes
+                and corresponding scores and labels.
         """
         # decode boxes
         obj_scores = F.softmax(
-            bbox_preds['obj_scores' + suffix], dim=-1)[..., -1]
+            feats_dict['obj_scores' + suffix], dim=-1)[..., -1]
-        sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)
+        sem_scores = F.softmax(feats_dict['sem_scores'], dim=-1)

         prediction_collection = {}
-        prediction_collection['center'] = bbox_preds['center' + suffix]
-        prediction_collection['dir_class'] = bbox_preds['dir_class']
-        prediction_collection['dir_res'] = bbox_preds['dir_res' + suffix]
-        prediction_collection['size_class'] = bbox_preds['size_class']
-        prediction_collection['size_res'] = bbox_preds['size_res' + suffix]
+        prediction_collection['center'] = feats_dict['center' + suffix]
+        prediction_collection['dir_class'] = feats_dict['dir_class']
+        prediction_collection['dir_res'] = feats_dict['dir_res' + suffix]
+        prediction_collection['size_class'] = feats_dict['size_class']
+        prediction_collection['size_res'] = feats_dict['size_res' + suffix]

         bbox3d = self.bbox_coder.decode(prediction_collection)

         batch_size = bbox3d.shape[0]
-        results = list()
+        results_list = list()
+        points = torch.stack(points)
         for b in range(batch_size):
+            temp_results = InstanceData()
             bbox_selected, score_selected, labels = self.multiclass_nms_single(
                 obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3],
-                input_metas[b])
-            bbox = input_metas[b]['box_type_3d'](
+                batch_input_metas[b])
+            bbox = batch_input_metas[b]['box_type_3d'](
                 bbox_selected,
                 box_dim=bbox_selected.shape[-1],
                 with_yaw=self.bbox_coder.with_rot)
-            results.append((bbox, score_selected, labels))
-        return results
+            temp_results.bboxes_3d = bbox
+            temp_results.scores_3d = score_selected
+            temp_results.labels_3d = labels
+            results_list.append(temp_results)
+
+        return results_list
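`predict_by_feat` now returns `InstanceData` containers instead of plain tuples; a hedged sketch of the container's behaviour (field names follow the code above, sizes are arbitrary):

import torch
from mmengine import InstanceData

results = InstanceData()
results.scores_3d = torch.rand(16)               # one score per box
results.labels_3d = torch.randint(0, 18, (16,))  # one label per box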
def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points, def multiclass_nms_single(self, obj_scores: Tensor, sem_scores: Tensor,
input_meta): bbox: Tensor, points: Tensor,
input_meta: dict) -> Tuple:
"""Multi-class nms in single batch. """Multi-class nms in single batch.
Args: Args:
...@@ -586,13 +655,13 @@ class H3DBboxHead(BaseModule): ...@@ -586,13 +655,13 @@ class H3DBboxHead(BaseModule):
dict: Losses of aggregation module. dict: Losses of aggregation module.
""" """
# calculate objectness loss # calculate objectness loss
objectness_loss = self.objectness_loss( objectness_loss = self.loss_objectness(
bbox_preds['obj_scores' + suffix].transpose(2, 1), bbox_preds['obj_scores' + suffix].transpose(2, 1),
objectness_targets, objectness_targets,
weight=objectness_weights) weight=objectness_weights)
# calculate center loss # calculate center loss
source2target_loss, target2source_loss = self.center_loss( source2target_loss, target2source_loss = self.loss_center(
bbox_preds['center' + suffix], bbox_preds['center' + suffix],
center_targets, center_targets,
src_weight=box_loss_weights, src_weight=box_loss_weights,
...@@ -600,7 +669,7 @@ class H3DBboxHead(BaseModule): ...@@ -600,7 +669,7 @@ class H3DBboxHead(BaseModule):
center_loss = source2target_loss + target2source_loss center_loss = source2target_loss + target2source_loss
# calculate direction class loss # calculate direction class loss
dir_class_loss = self.dir_class_loss( dir_class_loss = self.loss_dir_class(
bbox_preds['dir_class' + suffix].transpose(2, 1), bbox_preds['dir_class' + suffix].transpose(2, 1),
dir_class_targets, dir_class_targets,
weight=box_loss_weights) weight=box_loss_weights)
...@@ -612,11 +681,11 @@ class H3DBboxHead(BaseModule): ...@@ -612,11 +681,11 @@ class H3DBboxHead(BaseModule):
heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1) heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
dir_res_norm = (bbox_preds['dir_res_norm' + suffix] * dir_res_norm = (bbox_preds['dir_res_norm' + suffix] *
heading_label_one_hot).sum(dim=-1) heading_label_one_hot).sum(dim=-1)
dir_res_loss = self.dir_res_loss( dir_res_loss = self.loss_dir_res(
dir_res_norm, dir_res_targets, weight=box_loss_weights) dir_res_norm, dir_res_targets, weight=box_loss_weights)
# calculate size class loss # calculate size class loss
size_class_loss = self.size_class_loss( size_class_loss = self.loss_size_class(
bbox_preds['size_class' + suffix].transpose(2, 1), bbox_preds['size_class' + suffix].transpose(2, 1),
size_class_targets, size_class_targets,
weight=box_loss_weights) weight=box_loss_weights)
...@@ -631,13 +700,13 @@ class H3DBboxHead(BaseModule): ...@@ -631,13 +700,13 @@ class H3DBboxHead(BaseModule):
one_hot_size_targets_expand).sum(dim=2) one_hot_size_targets_expand).sum(dim=2)
box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat( box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat(
1, 1, 3) 1, 1, 3)
size_res_loss = self.size_res_loss( size_res_loss = self.loss_size_res(
size_residual_norm, size_residual_norm,
size_res_targets, size_res_targets,
weight=box_loss_weights_expand) weight=box_loss_weights_expand)
# calculate semantic loss # calculate semantic loss
semantic_loss = self.semantic_loss( semantic_loss = self.loss_semantic(
bbox_preds['sem_scores' + suffix].transpose(2, 1), bbox_preds['sem_scores' + suffix].transpose(2, 1),
mask_targets, mask_targets,
weight=box_loss_weights) weight=box_loss_weights)
...@@ -653,91 +722,93 @@ class H3DBboxHead(BaseModule): ...@@ -653,91 +722,93 @@ class H3DBboxHead(BaseModule):
return losses return losses
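The direction-residual term above hinges on a one-hot selection trick: scattering ones at the ground-truth direction bins and summing over the bin dimension keeps exactly the residual predicted for each target bin. A standalone sketch with illustrative shapes:

import torch

batch, num_proposals, num_bins = 2, 4, 12     # illustrative sizes
dir_res_norm_pred = torch.randn(batch, num_proposals, num_bins)
dir_class_targets = torch.randint(0, num_bins, (batch, num_proposals))

one_hot = dir_res_norm_pred.new_zeros(batch, num_proposals, num_bins)
one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
# keep only the residual of the ground-truth bin for each proposal
dir_res_norm = (dir_res_norm_pred * one_hot).sum(dim=-1)
print(dir_res_norm.shape)                     # torch.Size([2, 4])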
def get_targets(self, def get_targets(
points, self,
gt_bboxes_3d, points,
gt_labels_3d, feats_dict: Optional[dict] = None,
pts_semantic_mask=None, batch_gt_instances_3d: Optional[List[InstanceData]] = None,
pts_instance_mask=None, ):
bbox_preds=None): """Generate targets of the H3D bbox head.
"""Generate targets of proposal module.
Args: Args:
points (list[torch.Tensor]): Points of each batch. points (list[torch.Tensor]): Points of each batch.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth feats_dict (dict, optional): Predictions of previous
bboxes of each batch. components. Defaults to None.
gt_labels_3d (list[torch.Tensor]): Labels of each batch. batch_gt_instances_3d (list[:obj:`InstanceData`], optional):
pts_semantic_mask (list[torch.Tensor]): Point-wise semantic Batch of gt_instances. It usually includes
label of each batch. ``bboxes_3d`` and ``labels_3d`` attributes.
pts_instance_mask (list[torch.Tensor]): Point-wise instance
label of each batch.
bbox_preds (torch.Tensor): Bounding box predictions of vote head.
Returns: Returns:
tuple[torch.Tensor]: Targets of proposal module. tuple[torch.Tensor]: Targets of the H3D bbox head.
""" """
# find empty example # find empty example
valid_gt_masks = list() valid_gt_masks = list()
gt_num = list() gt_num = list()
for index in range(len(gt_labels_3d)): batch_gt_labels_3d = [
if len(gt_labels_3d[index]) == 0: gt_instances_3d.labels_3d
fake_box = gt_bboxes_3d[index].tensor.new_zeros( for gt_instances_3d in batch_gt_instances_3d
1, gt_bboxes_3d[index].tensor.shape[-1]) ]
gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box) batch_gt_bboxes_3d = [
gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1) gt_instances_3d.bboxes_3d
valid_gt_masks.append(gt_labels_3d[index].new_zeros(1)) for gt_instances_3d in batch_gt_instances_3d
]
for index in range(len(batch_gt_labels_3d)):
if len(batch_gt_labels_3d[index]) == 0:
fake_box = batch_gt_bboxes_3d[index].tensor.new_zeros(
1, batch_gt_bboxes_3d[index].tensor.shape[-1])
batch_gt_bboxes_3d[index] = batch_gt_bboxes_3d[index].new_box(
fake_box)
batch_gt_labels_3d[index] = batch_gt_labels_3d[
index].new_zeros(1)
valid_gt_masks.append(batch_gt_labels_3d[index].new_zeros(1))
gt_num.append(1) gt_num.append(1)
else: else:
valid_gt_masks.append(gt_labels_3d[index].new_ones( valid_gt_masks.append(batch_gt_labels_3d[index].new_ones(
gt_labels_3d[index].shape)) batch_gt_labels_3d[index].shape))
gt_num.append(gt_labels_3d[index].shape[0]) gt_num.append(batch_gt_labels_3d[index].shape[0])
if pts_semantic_mask is None:
pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
pts_instance_mask = [None for i in range(len(gt_labels_3d))]
aggregated_points = [ aggregated_points = [
bbox_preds['aggregated_points'][i] feats_dict['aggregated_points'][i]
for i in range(len(gt_labels_3d)) for i in range(len(batch_gt_labels_3d))
] ]
surface_center_pred = [ surface_center_pred = [
bbox_preds['surface_center_pred'][i] feats_dict['surface_center_pred'][i]
for i in range(len(gt_labels_3d)) for i in range(len(batch_gt_labels_3d))
] ]
line_center_pred = [ line_center_pred = [
bbox_preds['pred_line_center'][i] feats_dict['pred_line_center'][i]
for i in range(len(gt_labels_3d)) for i in range(len(batch_gt_labels_3d))
] ]
surface_center_object = [ surface_center_object = [
bbox_preds['surface_center_object'][i] feats_dict['surface_center_object'][i]
for i in range(len(gt_labels_3d)) for i in range(len(batch_gt_labels_3d))
] ]
line_center_object = [ line_center_object = [
bbox_preds['line_center_object'][i] feats_dict['line_center_object'][i]
for i in range(len(gt_labels_3d)) for i in range(len(batch_gt_labels_3d))
] ]
surface_sem_pred = [ surface_sem_pred = [
bbox_preds['surface_sem_pred'][i] feats_dict['surface_sem_pred'][i]
for i in range(len(gt_labels_3d)) for i in range(len(batch_gt_labels_3d))
] ]
line_sem_pred = [ line_sem_pred = [
bbox_preds['sem_cls_scores_line'][i] feats_dict['sem_cls_scores_line'][i]
for i in range(len(gt_labels_3d)) for i in range(len(batch_gt_labels_3d))
] ]
(cues_objectness_label, cues_sem_label, proposal_objectness_label, (cues_objectness_label, cues_sem_label, proposal_objectness_label,
cues_mask, cues_match_mask, proposal_objectness_mask, cues_mask, cues_match_mask, proposal_objectness_mask,
cues_matching_label, obj_surface_line_center) = multi_apply( cues_matching_label, obj_surface_line_center) = multi_apply(
self.get_targets_single, points, gt_bboxes_3d, gt_labels_3d, self._get_targets_single, points, batch_gt_bboxes_3d,
pts_semantic_mask, pts_instance_mask, aggregated_points, batch_gt_labels_3d, aggregated_points, surface_center_pred,
surface_center_pred, line_center_pred, surface_center_object, line_center_pred, surface_center_object, line_center_object,
line_center_object, surface_sem_pred, line_sem_pred) surface_sem_pred, line_sem_pred)
cues_objectness_label = torch.stack(cues_objectness_label) cues_objectness_label = torch.stack(cues_objectness_label)
cues_sem_label = torch.stack(cues_sem_label) cues_sem_label = torch.stack(cues_sem_label)
...@@ -753,19 +824,17 @@ class H3DBboxHead(BaseModule): ...@@ -753,19 +824,17 @@ class H3DBboxHead(BaseModule):
proposal_objectness_mask, cues_matching_label, proposal_objectness_mask, cues_matching_label,
obj_surface_line_center) obj_surface_line_center)
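`multi_apply` runs the per-sample target function across the batch lists and transposes the per-sample tuples into per-field tuples, which is why the eight target tensors can be unpacked directly above. A reference sketch mirroring the mmdet utility:

from functools import partial

def multi_apply(func, *args, **kwargs):
    # Apply `func` to each group of per-sample arguments, then regroup
    # the outputs so each returned element collects one field over the
    # whole batch.
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))

# multi_apply(lambda a, b: (a + b, a * b), [1, 2], [3, 4])
# -> ([4, 6], [3, 8])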
def get_targets_single(self, def _get_targets_single(self,
points, points: Tensor,
gt_bboxes_3d, gt_bboxes_3d: BaseInstance3DBoxes,
gt_labels_3d, gt_labels_3d: Tensor,
pts_semantic_mask=None, aggregated_points: Optional[Tensor] = None,
pts_instance_mask=None, pred_surface_center: Optional[Tensor] = None,
aggregated_points=None, pred_line_center: Optional[Tensor] = None,
pred_surface_center=None, pred_obj_surface_center: Optional[Tensor] = None,
pred_line_center=None, pred_obj_line_center: Optional[Tensor] = None,
pred_obj_surface_center=None, pred_surface_sem: Optional[Tensor] = None,
pred_obj_line_center=None, pred_line_sem: Optional[Tensor] = None):
pred_surface_sem=None,
pred_line_sem=None):
"""Generate targets for primitive cues for single batch. """Generate targets for primitive cues for single batch.
Args: Args:
...@@ -773,10 +842,6 @@ class H3DBboxHead(BaseModule): ...@@ -773,10 +842,6 @@ class H3DBboxHead(BaseModule):
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth
boxes of each batch. boxes of each batch.
gt_labels_3d (torch.Tensor): Labels of each batch. gt_labels_3d (torch.Tensor): Labels of each batch.
pts_semantic_mask (torch.Tensor): Point-wise semantic
label of each batch.
pts_instance_mask (torch.Tensor): Point-wise instance
label of each batch.
aggregated_points (torch.Tensor): Aggregated points from aggregated_points (torch.Tensor): Aggregated points from
vote aggregation layer. vote aggregation layer.
pred_surface_center (torch.Tensor): Prediction of surface center. pred_surface_center (torch.Tensor): Prediction of surface center.
...@@ -847,12 +912,10 @@ class H3DBboxHead(BaseModule): ...@@ -847,12 +912,10 @@ class H3DBboxHead(BaseModule):
euclidean_dist_line = torch.sqrt(dist_line.squeeze(0) + 1e-6) euclidean_dist_line = torch.sqrt(dist_line.squeeze(0) + 1e-6)
objectness_label_surface = euclidean_dist_line.new_zeros( objectness_label_surface = euclidean_dist_line.new_zeros(
num_proposals * 6, dtype=torch.long) num_proposals * 6, dtype=torch.long)
objectness_mask_surface = euclidean_dist_line.new_zeros(num_proposals *
6)
objectness_label_line = euclidean_dist_line.new_zeros( objectness_label_line = euclidean_dist_line.new_zeros(
num_proposals * 12, dtype=torch.long) num_proposals * 12, dtype=torch.long)
objectness_mask_line = euclidean_dist_line.new_zeros(num_proposals *
12)
objectness_label_surface_sem = euclidean_dist_line.new_zeros( objectness_label_surface_sem = euclidean_dist_line.new_zeros(
num_proposals * 6, dtype=torch.long) num_proposals * 6, dtype=torch.long)
objectness_label_line_sem = euclidean_dist_line.new_zeros( objectness_label_line_sem = euclidean_dist_line.new_zeros(
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.core.bbox import bbox3d2result from typing import Dict, List
from mmengine import InstanceData
from torch import Tensor
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from ...core import Det3DDataSample
from .base_3droi_head import Base3DRoIHead from .base_3droi_head import Base3DRoIHead
...@@ -16,17 +21,15 @@ class H3DRoIHead(Base3DRoIHead): ...@@ -16,17 +21,15 @@ class H3DRoIHead(Base3DRoIHead):
""" """
def __init__(self, def __init__(self,
primitive_list, primitive_list: List[dict],
bbox_head=None, bbox_head: dict = None,
train_cfg=None, train_cfg: dict = None,
test_cfg=None, test_cfg: dict = None,
pretrained=None, init_cfg: dict = None):
init_cfg=None):
super(H3DRoIHead, self).__init__( super(H3DRoIHead, self).__init__(
bbox_head=bbox_head, bbox_head=bbox_head,
train_cfg=train_cfg, train_cfg=train_cfg,
test_cfg=test_cfg, test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg) init_cfg=init_cfg)
# Primitive module # Primitive module
assert len(primitive_list) == 3 assert len(primitive_list) == 3
...@@ -39,8 +42,14 @@ class H3DRoIHead(Base3DRoIHead): ...@@ -39,8 +42,14 @@ class H3DRoIHead(Base3DRoIHead):
one.""" one."""
pass pass
def init_bbox_head(self, bbox_head): def init_bbox_head(self, dummy_args, bbox_head):
"""Initialize box head.""" """Initialize box head.
Args:
dummy_args (optional): Just to be compatible with
the interface in the base class.
bbox_head (dict): Config for bbox head.
"""
bbox_head['train_cfg'] = self.train_cfg bbox_head['train_cfg'] = self.train_cfg
bbox_head['test_cfg'] = self.test_cfg bbox_head['test_cfg'] = self.test_cfg
self.bbox_head = MODELS.build(bbox_head) self.bbox_head = MODELS.build(bbox_head)
...@@ -49,111 +58,73 @@ class H3DRoIHead(Base3DRoIHead): ...@@ -49,111 +58,73 @@ class H3DRoIHead(Base3DRoIHead):
"""Initialize assigner and sampler.""" """Initialize assigner and sampler."""
pass pass
def forward_train(self, def loss(self, points: List[Tensor], feats_dict: dict,
feats_dict, batch_data_samples: List[Det3DDataSample], **kwargs):
img_metas,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask,
pts_instance_mask,
gt_bboxes_ignore=None):
"""Training forward function of PartAggregationROIHead. """Training forward function of PartAggregationROIHead.
Args: Args:
feats_dict (dict): Contains features from the first stage. points (list[torch.Tensor]): Point cloud of each sample.
img_metas (list[dict]): Contain pcd and img's meta info. feats_dict (dict): Dict of features.
points (list[torch.Tensor]): Input points. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth Samples. It usually includes information such as
bboxes of each sample. `gt_instance_3d`.
gt_labels_3d (list[torch.Tensor]): Labels of each sample. `gt_instances_3d`.
pts_semantic_mask (list[torch.Tensor]): Point-wise
semantic mask.
pts_instance_mask (list[torch.Tensor]): Point-wise
instance mask.
gt_bboxes_ignore (list[torch.Tensor]): Specify
which bounding boxes to ignore.
Returns: Returns:
dict: losses from each head. dict: losses from each head.
""" """
losses = dict() losses = dict()
sample_mod = self.train_cfg.sample_mod primitive_loss_inputs = (points, feats_dict, batch_data_samples)
assert sample_mod in ['vote', 'seed', 'random'] # note: each head adds new keys and values to feats_dict.
result_z = self.primitive_z(feats_dict, sample_mod)
feats_dict.update(result_z)
result_xy = self.primitive_xy(feats_dict, sample_mod)
feats_dict.update(result_xy)
result_line = self.primitive_line(feats_dict, sample_mod)
feats_dict.update(result_line)
primitive_loss_inputs = (feats_dict, points, gt_bboxes_3d,
gt_labels_3d, pts_semantic_mask,
pts_instance_mask, img_metas,
gt_bboxes_ignore)
loss_z = self.primitive_z.loss(*primitive_loss_inputs) loss_z = self.primitive_z.loss(*primitive_loss_inputs)
losses.update(loss_z)
loss_xy = self.primitive_xy.loss(*primitive_loss_inputs) loss_xy = self.primitive_xy.loss(*primitive_loss_inputs)
losses.update(loss_xy)
loss_line = self.primitive_line.loss(*primitive_loss_inputs) loss_line = self.primitive_line.loss(*primitive_loss_inputs)
losses.update(loss_z)
losses.update(loss_xy)
losses.update(loss_line) losses.update(loss_line)
targets = feats_dict.pop('targets') targets = feats_dict.pop('targets')
bbox_results = self.bbox_head(feats_dict, sample_mod) bbox_loss = self.bbox_head.loss(
points,
feats_dict.update(bbox_results) feats_dict,
bbox_loss = self.bbox_head.loss(feats_dict, points, gt_bboxes_3d, rpn_targets=targets,
gt_labels_3d, pts_semantic_mask, batch_data_samples=batch_data_samples)
pts_instance_mask, img_metas, targets,
gt_bboxes_ignore)
losses.update(bbox_loss) losses.update(bbox_loss)
return losses return losses
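The three primitive heads can write into one shared `losses` dict without clobbering each other because every loss key carries the primitive mode as a suffix (`flag_loss_z`, `vote_loss_xy`, ...), as the PrimitiveHead code further below shows. A toy illustration of the merge:

losses = dict()
for mode in ('z', 'xy', 'line'):              # one pass per primitive head
    head_losses = {
        'flag_loss_' + mode: 0.1,             # illustrative values
        'vote_loss_' + mode: 0.2,
    }
    losses.update(head_losses)                # keys never collide
print(sorted(losses))                         # six distinct keys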
def simple_test(self, feats_dict, img_metas, points, rescale=False): def predict(self,
"""Simple testing forward function of PartAggregationROIHead. points: List[Tensor],
feats_dict: Dict[str, Tensor],
Note: batch_data_samples: List[Det3DDataSample],
This function assumes that the batch size is 1 suffix='_optimized',
**kwargs) -> List[InstanceData]:
"""
Args: Args:
feats_dict (dict): Contains features from the first stage. points (list[tensor]): Point clouds of multiple samples.
img_metas (list[dict]): Contain pcd and img's meta info. feats_dict (dict): Features from FPN or backbone.
points (torch.Tensor): Input points. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
rescale (bool): Whether to rescale results. Samples. It usually includes meta information of data.
Returns: Returns:
dict: Bbox results of one frame. list[:obj:`InstanceData`]: List of processed predictions. Each
InstanceData contains 3D bounding boxes and corresponding
scores and labels.
""" """
sample_mod = self.test_cfg.sample_mod
assert sample_mod in ['vote', 'seed', 'random']
result_z = self.primitive_z(feats_dict, sample_mod) result_z = self.primitive_z(feats_dict)
feats_dict.update(result_z) feats_dict.update(result_z)
result_xy = self.primitive_xy(feats_dict, sample_mod) result_xy = self.primitive_xy(feats_dict)
feats_dict.update(result_xy) feats_dict.update(result_xy)
result_line = self.primitive_line(feats_dict, sample_mod) result_line = self.primitive_line(feats_dict)
feats_dict.update(result_line) feats_dict.update(result_line)
bbox_preds = self.bbox_head(feats_dict, sample_mod) bbox_preds = self.bbox_head(feats_dict)
feats_dict.update(bbox_preds) feats_dict.update(bbox_preds)
bbox_list = self.bbox_head.get_bboxes( results_list = self.bbox_head.predict(
points, points, feats_dict, batch_data_samples, suffix=suffix)
feats_dict,
img_metas, return results_list
rescale=rescale,
suffix='_optimized')
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results
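Coarse and optimized proposals coexist in `feats_dict`, and the `suffix` argument picks which set the final decoding reads. A minimal sketch of the naming convention, with illustrative values:

feats_dict = {
    'center': 'coarse centers',               # first-stage proposal output
    'center_optimized': 'refined centers',    # written by the H3D bbox head
}
suffix = '_optimized'
print(feats_dict['center' + suffix])          # -> 'refined centers'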
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.ops import furthest_point_sample from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmengine import InstanceData
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.models.builder import build_loss from mmdet3d.core import Det3DDataSample
from mmdet3d.models.model_utils import VoteModule from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import build_sa_module from mmdet3d.ops import build_sa_module
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
...@@ -40,24 +43,25 @@ class PrimitiveHead(BaseModule): ...@@ -40,24 +43,25 @@ class PrimitiveHead(BaseModule):
""" """
def __init__(self, def __init__(self,
num_dims, num_dims: int,
num_classes, num_classes: int,
primitive_mode, primitive_mode: str,
train_cfg=None, train_cfg: dict = None,
test_cfg=None, test_cfg: dict = None,
vote_module_cfg=None, vote_module_cfg: dict = None,
vote_aggregation_cfg=None, vote_aggregation_cfg: dict = None,
feat_channels=(128, 128), feat_channels: tuple = (128, 128),
upper_thresh=100.0, upper_thresh: float = 100.0,
surface_thresh=0.5, surface_thresh: float = 0.5,
conv_cfg=dict(type='Conv1d'), conv_cfg: dict = dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'), norm_cfg: dict = dict(type='BN1d'),
objectness_loss=None, objectness_loss: dict = None,
center_loss=None, center_loss: dict = None,
semantic_reg_loss=None, semantic_reg_loss: dict = None,
semantic_cls_loss=None, semantic_cls_loss: dict = None,
init_cfg=None): init_cfg: dict = None):
super(PrimitiveHead, self).__init__(init_cfg=init_cfg) super(PrimitiveHead, self).__init__(init_cfg=init_cfg)
# bounding box centers, face centers and edge centers
assert primitive_mode in ['z', 'xy', 'line'] assert primitive_mode in ['z', 'xy', 'line']
# The dimension of primitive semantic information. # The dimension of primitive semantic information.
self.num_dims = num_dims self.num_dims = num_dims
...@@ -70,10 +74,10 @@ class PrimitiveHead(BaseModule): ...@@ -70,10 +74,10 @@ class PrimitiveHead(BaseModule):
self.upper_thresh = upper_thresh self.upper_thresh = upper_thresh
self.surface_thresh = surface_thresh self.surface_thresh = surface_thresh
self.objectness_loss = build_loss(objectness_loss) self.loss_objectness = MODELS.build(objectness_loss)
self.center_loss = build_loss(center_loss) self.loss_center = MODELS.build(center_loss)
self.semantic_reg_loss = build_loss(semantic_reg_loss) self.loss_semantic_reg = MODELS.build(semantic_reg_loss)
self.semantic_cls_loss = build_loss(semantic_cls_loss) self.loss_semantic_cls = MODELS.build(semantic_cls_loss)
assert vote_aggregation_cfg['mlp_channels'][0] == vote_module_cfg[ assert vote_aggregation_cfg['mlp_channels'][0] == vote_module_cfg[
'in_channels'] 'in_channels']
...@@ -114,18 +118,26 @@ class PrimitiveHead(BaseModule): ...@@ -114,18 +118,26 @@ class PrimitiveHead(BaseModule):
self.conv_pred.add_module('conv_out', self.conv_pred.add_module('conv_out',
nn.Conv1d(prev_channel, conv_out_channel, 1)) nn.Conv1d(prev_channel, conv_out_channel, 1))
def forward(self, feats_dict, sample_mod): @property
def sample_mode(self):
if self.training:
sample_mode = self.train_cfg.sample_mode
else:
sample_mode = self.test_cfg.sample_mode
assert sample_mode in ['vote', 'seed', 'random']
return sample_mode
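With the property in place, the sampling strategy no longer has to be threaded through every forward call; it is read from whichever config matches the current train/eval mode. A config sketch matching the property's expectations (the value must be one of 'vote', 'seed' or 'random'):

train_cfg = dict(sample_mode='vote')          # FPS inside vote aggregation
test_cfg = dict(sample_mode='seed')           # FPS on seeds, gather their votes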
def forward(self, feats_dict):
"""Forward pass. """Forward pass.
Args: Args:
feats_dict (dict): Feature dict from backbone. feats_dict (dict): Feature dict from backbone.
sample_mod (str): Sample mode for vote aggregation layer.
valid modes are "vote", "seed" and "random".
Returns: Returns:
dict: Predictions of primitive head. dict: Predictions of primitive head.
""" """
assert sample_mod in ['vote', 'seed', 'random'] sample_mode = self.sample_mode
seed_points = feats_dict['fp_xyz_net0'][-1] seed_points = feats_dict['fp_xyz_net0'][-1]
seed_features = feats_dict['hd_feature'] seed_features = feats_dict['hd_feature']
...@@ -143,14 +155,14 @@ class PrimitiveHead(BaseModule): ...@@ -143,14 +155,14 @@ class PrimitiveHead(BaseModule):
results['vote_features_' + self.primitive_mode] = vote_features results['vote_features_' + self.primitive_mode] = vote_features
# 2. aggregate vote_points # 2. aggregate vote_points
if sample_mod == 'vote': if sample_mode == 'vote':
# use fps in vote_aggregation # use fps in vote_aggregation
sample_indices = None sample_indices = None
elif sample_mod == 'seed': elif sample_mode == 'seed':
# FPS on seed and choose the votes corresponding to the seeds # FPS on seed and choose the votes corresponding to the seeds
sample_indices = furthest_point_sample(seed_points, sample_indices = furthest_point_sample(seed_points,
self.num_proposal) self.num_proposal)
elif sample_mod == 'random': elif sample_mode == 'random':
# Random sampling from the votes # Random sampling from the votes
batch_size, num_seed = seed_points.shape[:2] batch_size, num_seed = seed_points.shape[:2]
sample_indices = torch.randint( sample_indices = torch.randint(
...@@ -185,63 +197,103 @@ class PrimitiveHead(BaseModule): ...@@ -185,63 +197,103 @@ class PrimitiveHead(BaseModule):
results['pred_' + self.primitive_mode + '_center'] = center results['pred_' + self.primitive_mode + '_center'] = center
return results return results
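Only the 'seed' and 'random' branches precompute explicit indices; 'vote' defers sampling to FPS inside the aggregation module. The 'random' branch boils down to a single torch.randint draw per sample, sketched here with illustrative shapes:

import torch

batch_size, num_seed, num_proposal = 2, 1024, 256   # illustrative sizes
sample_indices = torch.randint(
    0, num_seed, (batch_size, num_proposal), dtype=torch.int32)
print(sample_indices.shape)                   # torch.Size([2, 256])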
def loss(self, def loss(self, points: List[torch.Tensor], feats_dict: Dict[str,
bbox_preds, torch.Tensor],
points, batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
gt_bboxes_3d, """
gt_labels_3d, Args:
pts_semantic_mask=None, points (list[tensor]): Point clouds of multiple samples.
pts_instance_mask=None, feats_dict (dict): Predictions from backbone or FPN.
img_metas=None, batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
gt_bboxes_ignore=None): contains the meta information of each sample and
corresponding annotations.
Returns:
dict: A dictionary of loss components.
"""
preds = self(feats_dict)
feats_dict.update(preds)
batch_gt_instance_3d = []
batch_gt_instances_ignore = []
batch_input_metas = []
batch_pts_semantic_mask = []
batch_pts_instance_mask = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
batch_gt_instance_3d.append(data_sample.gt_instances_3d)
batch_gt_instances_ignore.append(
data_sample.get('ignored_instances', None))
batch_pts_semantic_mask.append(
data_sample.gt_pts_seg.get('pts_semantic_mask', None))
batch_pts_instance_mask.append(
data_sample.gt_pts_seg.get('pts_instance_mask', None))
loss_inputs = (points, feats_dict, batch_gt_instance_3d)
losses = self.loss_by_feat(
*loss_inputs,
batch_pts_semantic_mask=batch_pts_semantic_mask,
batch_pts_instance_mask=batch_pts_instance_mask,
batch_gt_instances_ignore=batch_gt_instances_ignore,
)
return losses
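The `loss` entry point follows a fixed recipe: run the head's own forward on the shared features, unpack the data samples into parallel per-field lists, then delegate to `loss_by_feat`. A condensed sketch of that control flow, with names mirroring the method and all tensor work stubbed out:

def loss(head, points, feats_dict, batch_data_samples):
    feats_dict.update(head(feats_dict))       # forward pass on features
    gt_instances = [ds.gt_instances_3d for ds in batch_data_samples]
    sem_masks = [ds.gt_pts_seg.get('pts_semantic_mask', None)
                 for ds in batch_data_samples]
    return head.loss_by_feat(points, feats_dict, gt_instances,
                             batch_pts_semantic_mask=sem_masks)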
def loss_by_feat(
self,
points: List[torch.Tensor],
feats_dict: dict,
batch_gt_instances_3d: List[InstanceData],
batch_pts_semantic_mask: Optional[List[torch.Tensor]] = None,
batch_pts_instance_mask: Optional[List[torch.Tensor]] = None,
**kwargs):
"""Compute loss. """Compute loss.
Args: Args:
bbox_preds (dict): Predictions from forward of primitive head.
points (list[torch.Tensor]): Input points. points (list[torch.Tensor]): Input points.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth feats_dict (dict): Predictions of previous modules.
bboxes of each sample. batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_labels_3d (list[torch.Tensor]): Labels of each sample. gt_instances. It usually includes ``bboxes`` and ``labels``
pts_semantic_mask (list[torch.Tensor]): Point-wise attributes.
semantic mask. batch_pts_semantic_mask (list[tensor]): Semantic mask
pts_instance_mask (list[torch.Tensor]): Point-wise of point clouds. Defaults to None.
instance mask. batch_pts_instance_mask (list[tensor]): Instance mask
img_metas (list[dict]): Contain pcd and img's meta info. of point clouds. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor]): Specify batch_input_metas (list[dict]): Contain pcd and img's meta info.
which bounding boxes to ignore. ret_target (bool): Return targets or not. Defaults to False.
Returns: Returns:
dict: Losses of Primitive Head. dict: Losses of Primitive Head.
""" """
targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask, targets = self.get_targets(points, feats_dict, batch_gt_instances_3d,
bbox_preds) batch_pts_semantic_mask,
batch_pts_instance_mask)
(point_mask, point_offset, gt_primitive_center, gt_primitive_semantic, (point_mask, point_offset, gt_primitive_center, gt_primitive_semantic,
gt_sem_cls_label, gt_primitive_mask) = targets gt_sem_cls_label, gt_primitive_mask) = targets
losses = {} losses = {}
# Compute the loss of primitive existence flag # Compute the loss of primitive existence flag
pred_flag = bbox_preds['pred_flag_' + self.primitive_mode] pred_flag = feats_dict['pred_flag_' + self.primitive_mode]
flag_loss = self.objectness_loss(pred_flag, gt_primitive_mask.long()) flag_loss = self.loss_objectness(pred_flag, gt_primitive_mask.long())
losses['flag_loss_' + self.primitive_mode] = flag_loss losses['flag_loss_' + self.primitive_mode] = flag_loss
# calculate vote loss # calculate vote loss
vote_loss = self.vote_module.get_loss( vote_loss = self.vote_module.get_loss(
bbox_preds['seed_points'], feats_dict['seed_points'],
bbox_preds['vote_' + self.primitive_mode], feats_dict['vote_' + self.primitive_mode],
bbox_preds['seed_indices'], point_mask, point_offset) feats_dict['seed_indices'], point_mask, point_offset)
losses['vote_loss_' + self.primitive_mode] = vote_loss losses['vote_loss_' + self.primitive_mode] = vote_loss
num_proposal = bbox_preds['aggregated_points_' + num_proposal = feats_dict['aggregated_points_' +
self.primitive_mode].shape[1] self.primitive_mode].shape[1]
primitive_center = bbox_preds['center_' + self.primitive_mode] primitive_center = feats_dict['center_' + self.primitive_mode]
if self.primitive_mode != 'line': if self.primitive_mode != 'line':
primitive_semantic = bbox_preds['size_residuals_' + primitive_semantic = feats_dict['size_residuals_' +
self.primitive_mode].contiguous() self.primitive_mode].contiguous()
else: else:
primitive_semantic = None primitive_semantic = None
semancitc_scores = bbox_preds['sem_cls_scores_' + semancitc_scores = feats_dict['sem_cls_scores_' +
self.primitive_mode].transpose(2, 1) self.primitive_mode].transpose(2, 1)
gt_primitive_mask = gt_primitive_mask / \ gt_primitive_mask = gt_primitive_mask / \
...@@ -256,44 +308,61 @@ class PrimitiveHead(BaseModule): ...@@ -256,44 +308,61 @@ class PrimitiveHead(BaseModule):
return losses return losses
def get_targets(self, def get_targets(
points, self,
gt_bboxes_3d, points,
gt_labels_3d, bbox_preds: Optional[dict] = None,
pts_semantic_mask=None, batch_gt_instances_3d: List[InstanceData] = None,
pts_instance_mask=None, batch_pts_semantic_mask: List[torch.Tensor] = None,
bbox_preds=None): batch_pts_instance_mask: List[torch.Tensor] = None,
):
"""Generate targets of primitive head. """Generate targets of primitive head.
Args: Args:
points (list[torch.Tensor]): Points of each batch. points (list[torch.Tensor]): Points of each batch.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth bbox_preds (dict, optional): Predictions from the forward of
bboxes of each batch. the primitive head.
gt_labels_3d (list[torch.Tensor]): Labels of each batch. batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
pts_semantic_mask (list[torch.Tensor]): Point-wise semantic gt_instances. It usually includes ``bboxes_3d`` and
label of each batch. ``labels_3d`` attributes.
pts_instance_mask (list[torch.Tensor]): Point-wise instance batch_pts_semantic_mask (list[tensor]): Semantic gt mask for
label of each batch. multiple samples.
bbox_preds (dict): Predictions from forward of primitive head. batch_pts_instance_mask (list[tensor]): Instance gt mask for
multiple samples.
Returns: Returns:
tuple[torch.Tensor]: Targets of primitive head. tuple[torch.Tensor]: Targets of primitive head.
""" """
for index in range(len(gt_labels_3d)): batch_gt_labels_3d = [
if len(gt_labels_3d[index]) == 0: gt_instances_3d.labels_3d
fake_box = gt_bboxes_3d[index].tensor.new_zeros( for gt_instances_3d in batch_gt_instances_3d
1, gt_bboxes_3d[index].tensor.shape[-1]) ]
gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box) batch_gt_bboxes_3d = [
gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1) gt_instances_3d.bboxes_3d
for gt_instances_3d in batch_gt_instances_3d
if pts_semantic_mask is None: ]
pts_semantic_mask = [None for i in range(len(gt_labels_3d))] for index in range(len(batch_gt_labels_3d)):
pts_instance_mask = [None for i in range(len(gt_labels_3d))] if len(batch_gt_labels_3d[index]) == 0:
fake_box = batch_gt_bboxes_3d[index].tensor.new_zeros(
1, batch_gt_bboxes_3d[index].tensor.shape[-1])
batch_gt_bboxes_3d[index] = batch_gt_bboxes_3d[index].new_box(
fake_box)
batch_gt_labels_3d[index] = batch_gt_labels_3d[
index].new_zeros(1)
if batch_pts_semantic_mask is None:
batch_pts_semantic_mask = [
None for _ in range(len(batch_gt_labels_3d))
]
batch_pts_instance_mask = [
None for _ in range(len(batch_gt_labels_3d))
]
(point_mask, point_sem, (point_mask, point_sem,
point_offset) = multi_apply(self.get_targets_single, points, point_offset) = multi_apply(self.get_targets_single, points,
gt_bboxes_3d, gt_labels_3d, batch_gt_bboxes_3d, batch_gt_labels_3d,
pts_semantic_mask, pts_instance_mask) batch_pts_semantic_mask,
batch_pts_instance_mask)
point_mask = torch.stack(point_mask) point_mask = torch.stack(point_mask)
point_sem = torch.stack(point_sem) point_sem = torch.stack(point_sem)
...@@ -759,7 +828,7 @@ class PrimitiveHead(BaseModule): ...@@ -759,7 +828,7 @@ class PrimitiveHead(BaseModule):
vote_xyz_reshape = primitive_center.view(batch_size * num_proposal, -1, vote_xyz_reshape = primitive_center.view(batch_size * num_proposal, -1,
3) 3)
center_loss = self.center_loss( center_loss = self.loss_center(
vote_xyz_reshape, vote_xyz_reshape,
gt_primitive_center, gt_primitive_center,
dst_weight=gt_primitive_mask.view(batch_size * num_proposal, 1))[1] dst_weight=gt_primitive_mask.view(batch_size * num_proposal, 1))[1]
...@@ -767,7 +836,7 @@ class PrimitiveHead(BaseModule): ...@@ -767,7 +836,7 @@ class PrimitiveHead(BaseModule):
if self.primitive_mode != 'line': if self.primitive_mode != 'line':
size_xyz_reshape = primitive_semantic.view( size_xyz_reshape = primitive_semantic.view(
batch_size * num_proposal, -1, self.num_dims).contiguous() batch_size * num_proposal, -1, self.num_dims).contiguous()
size_loss = self.semantic_reg_loss( size_loss = self.loss_semantic_reg(
size_xyz_reshape, size_xyz_reshape,
gt_primitive_semantic, gt_primitive_semantic,
dst_weight=gt_primitive_mask.view(batch_size * num_proposal, dst_weight=gt_primitive_mask.view(batch_size * num_proposal,
...@@ -776,7 +845,7 @@ class PrimitiveHead(BaseModule): ...@@ -776,7 +845,7 @@ class PrimitiveHead(BaseModule):
size_loss = center_loss.new_tensor(0.0) size_loss = center_loss.new_tensor(0.0)
# Semantic cls loss # Semantic cls loss
sem_cls_loss = self.semantic_cls_loss( sem_cls_loss = self.loss_semantic_cls(
semantic_scores, gt_sem_cls_label, weight=gt_primitive_mask) semantic_scores, gt_sem_cls_label, weight=gt_primitive_mask)
return center_loss, size_loss, sem_cls_loss return center_loss, size_loss, sem_cls_loss
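The `[1]` index on the center loss above is deliberate: as with the aggregation losses earlier, the center loss returns a (source-to-target, target-to-source) pair, and the primitive head keeps only the second term, weighted per proposal through `dst_weight`. A toy analogue of that calling convention (not the real loss API):

def toy_center_loss(src, dst, dst_weight=1.0):
    loss_src = abs(src - dst)                 # source-to-target term
    loss_dst = dst_weight * abs(dst - src)    # target-to-source term
    return loss_src, loss_dst                 # the head keeps index [1]

center_loss = toy_center_loss(1.0, 0.5, dst_weight=2.0)[1]
print(center_loss)                            # 1.0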
......
import unittest
import torch
from mmengine import DefaultScope
from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
_get_detector_cfg, _setup_seed)
class TestH3D(unittest.TestCase):
def test_h3dnet(self):
import mmdet3d.models
assert hasattr(mmdet3d.models, 'H3DNet')
DefaultScope.get_instance('test_H3DNet', scope_name='mmdet3d')
_setup_seed(0)
h3dnet_cfg = _get_detector_cfg(
'h3dnet/h3dnet_3x8_scannet-3d-18class.py')
model = MODELS.build(h3dnet_cfg)
num_gt_instance = 5
data = [
_create_detector_inputs(
num_gt_instance=num_gt_instance,
points_feat_dim=4,
bboxes_3d_type='depth',
with_pts_semantic_mask=True,
with_pts_instance_mask=True)
]
if torch.cuda.is_available():
model = model.cuda()
# test predict
with torch.no_grad():
batch_inputs, data_samples = model.data_preprocessor(
data, True)
results = model.forward(
batch_inputs, data_samples, mode='predict')
self.assertEqual(len(results), len(data))
self.assertIn('bboxes_3d', results[0].pred_instances_3d)
self.assertIn('scores_3d', results[0].pred_instances_3d)
self.assertIn('labels_3d', results[0].pred_instances_3d)
# run the loss computation under no_grad to save memory
with torch.no_grad():
losses = model.forward(batch_inputs, data_samples, mode='loss')
self.assertGreater(losses['vote_loss'], 0)
self.assertGreater(losses['objectness_loss'], 0)
self.assertGreater(losses['center_loss'], 0)
...@@ -7,7 +7,8 @@ import numpy as np ...@@ -7,7 +7,8 @@ import numpy as np
import torch import torch
from mmengine import InstanceData from mmengine import InstanceData
from mmdet3d.core import Det3DDataSample, LiDARInstance3DBoxes, PointData from mmdet3d.core import (CameraInstance3DBoxes, DepthInstance3DBoxes,
Det3DDataSample, LiDARInstance3DBoxes, PointData)
def _setup_seed(seed): def _setup_seed(seed):
...@@ -71,19 +72,24 @@ def _get_detector_cfg(fname): ...@@ -71,19 +72,24 @@ def _get_detector_cfg(fname):
return model return model
def _create_detector_inputs( def _create_detector_inputs(seed=0,
seed=0, with_points=True,
with_points=True, with_img=False,
with_img=False, num_gt_instance=20,
num_gt_instance=20, num_points=10,
num_points=10, points_feat_dim=4,
points_feat_dim=4, num_classes=3,
num_classes=3, gt_bboxes_dim=7,
gt_bboxes_dim=7, with_pts_semantic_mask=False,
with_pts_semantic_mask=False, with_pts_instance_mask=False,
with_pts_instance_mask=False, bboxes_3d_type='lidar'):
):
_setup_seed(seed) _setup_seed(seed)
assert bboxes_3d_type in ('lidar', 'depth', 'cam')
bbox_3d_class = {
'lidar': LiDARInstance3DBoxes,
'depth': DepthInstance3DBoxes,
'cam': CameraInstance3DBoxes
}
if with_points: if with_points:
points = torch.rand([num_points, points_feat_dim]) points = torch.rand([num_points, points_feat_dim])
else: else:
...@@ -93,12 +99,13 @@ def _create_detector_inputs( ...@@ -93,12 +99,13 @@ def _create_detector_inputs(
else: else:
img = None img = None
inputs_dict = dict(img=img, points=points) inputs_dict = dict(img=img, points=points)
gt_instance_3d = InstanceData() gt_instance_3d = InstanceData()
gt_instance_3d.bboxes_3d = LiDARInstance3DBoxes( gt_instance_3d.bboxes_3d = bbox_3d_class[bboxes_3d_type](
torch.rand([num_gt_instance, gt_bboxes_dim]), box_dim=gt_bboxes_dim) torch.rand([num_gt_instance, gt_bboxes_dim]), box_dim=gt_bboxes_dim)
gt_instance_3d.labels_3d = torch.randint(0, num_classes, [num_gt_instance]) gt_instance_3d.labels_3d = torch.randint(0, num_classes, [num_gt_instance])
data_sample = Det3DDataSample( data_sample = Det3DDataSample(
metainfo=dict(box_type_3d=LiDARInstance3DBoxes)) metainfo=dict(box_type_3d=bbox_3d_class[bboxes_3d_type]))
data_sample.gt_instances_3d = gt_instance_3d data_sample.gt_instances_3d = gt_instance_3d
data_sample.gt_pts_seg = PointData() data_sample.gt_pts_seg = PointData()
if with_pts_instance_mask: if with_pts_instance_mask:
......