Unverified commit 16e8d143 authored by encore-zhou, committed by GitHub

add h3d head (#58)

* add h3d head

* add h3d roi head

* update docstring of h3d roi head

* reconstruct h3d head

* remove unused code

* modify h3d bbox head

* add h3dnet init files

* modify h3d bbox head

* add depth_box3d unittest

* update h3d head

* add h3dnet benchmark

* update docstring in vote_head

* resolve primitive conflict
parent 299d666a
primitive_z_cfg = dict(
type='PrimitiveHead',
num_dims=2,
num_classes=18,
primitive_mode='z',
upper_thresh=100.0,
surface_thresh=0.5,
vote_moudule_cfg=dict(
in_channels=256,
vote_per_seed=1,
gt_per_seed=1,
conv_channels=(256, 256),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
norm_feats=True,
vote_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='none',
loss_dst_weight=10.0)),
vote_aggregation_cfg=dict(
num_point=1024,
radius=0.3,
num_sample=16,
mlp_channels=[256, 128, 128, 128],
use_xyz=True,
normalize_xyz=True),
feat_channels=(128, 128),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.4, 0.6],
reduction='mean',
loss_weight=30.0),
center_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='sum',
loss_src_weight=0.5,
loss_dst_weight=0.5),
semantic_reg_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='sum',
loss_src_weight=0.5,
loss_dst_weight=0.5),
semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
train_cfg=dict(
dist_thresh=0.2,
var_thresh=1e-2,
lower_thresh=1e-6,
num_point=100,
num_point_line=10,
line_thresh=0.2))
primitive_xy_cfg = dict(
type='PrimitiveHead',
num_dims=1,
num_classes=18,
primitive_mode='xy',
upper_thresh=100.0,
surface_thresh=0.5,
vote_moudule_cfg=dict(
in_channels=256,
vote_per_seed=1,
gt_per_seed=1,
conv_channels=(256, 256),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
norm_feats=True,
vote_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='none',
loss_dst_weight=10.0)),
vote_aggregation_cfg=dict(
num_point=1024,
radius=0.3,
num_sample=16,
mlp_channels=[256, 128, 128, 128],
use_xyz=True,
normalize_xyz=True),
feat_channels=(128, 128),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.4, 0.6],
reduction='mean',
loss_weight=30.0),
center_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='sum',
loss_src_weight=0.5,
loss_dst_weight=0.5),
semantic_reg_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='sum',
loss_src_weight=0.5,
loss_dst_weight=0.5),
semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
train_cfg=dict(
dist_thresh=0.2,
var_thresh=1e-2,
lower_thresh=1e-6,
num_point=100,
num_point_line=10,
line_thresh=0.2))
primitive_line_cfg = dict(
type='PrimitiveHead',
num_dims=0,
num_classes=18,
primitive_mode='line',
upper_thresh=100.0,
surface_thresh=0.5,
vote_moudule_cfg=dict(
in_channels=256,
vote_per_seed=1,
gt_per_seed=1,
conv_channels=(256, 256),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
norm_feats=True,
vote_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='none',
loss_dst_weight=10.0)),
vote_aggregation_cfg=dict(
num_point=1024,
radius=0.3,
num_sample=16,
mlp_channels=[256, 128, 128, 128],
use_xyz=True,
normalize_xyz=True),
feat_channels=(128, 128),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.4, 0.6],
reduction='mean',
loss_weight=30.0),
center_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='sum',
loss_src_weight=1.0,
loss_dst_weight=1.0),
semantic_reg_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='sum',
loss_src_weight=1.0,
loss_dst_weight=1.0),
semantic_cls_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=2.0),
train_cfg=dict(
dist_thresh=0.2,
var_thresh=1e-2,
lower_thresh=1e-6,
num_point=100,
num_point_line=10,
line_thresh=0.2))
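# The refinement stage matches 6 surface centers and 12 line centers to
# each of the 256 proposals, hence num_point=256 * 6 for surface matching
# and num_point=256 * 12 for line matching in the module below.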
proposal_module_cfg = dict(
suface_matching_cfg=dict(
num_point=256 * 6,
radius=0.5,
num_sample=32,
mlp_channels=[128 + 6, 128, 64, 32],
use_xyz=True,
normalize_xyz=True),
line_matching_cfg=dict(
num_point=256 * 12,
radius=0.5,
num_sample=32,
mlp_channels=[128 + 12, 128, 64, 32],
use_xyz=True,
normalize_xyz=True),
primitive_refine_channels=[128, 128, 128],
upper_thresh=100.0,
surface_thresh=0.5,
line_thresh=0.5,
train_cfg=dict(
far_threshold=0.6,
near_threshold=0.3,
mask_surface_threshold=0.3,
label_surface_threshold=0.3,
mask_line_threshold=0.3,
label_line_threshold=0.3),
cues_objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
cues_semantic_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
proposal_objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='none',
loss_weight=5.0),
primitive_center_loss=dict(
type='MSELoss', reduction='none', loss_weight=1.0))
model = dict(
type='H3DNet',
backbone=dict(
type='MultiBackbone',
num_streams=4,
suffixes=['net0', 'net1', 'net2', 'net3'],
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01),
act_cfg=dict(type='ReLU'),
backbones=dict(
type='PointNet2SASSG',
in_channels=4,
num_points=(2048, 1024, 512, 256),
radius=(0.2, 0.4, 0.8, 1.2),
num_samples=(64, 32, 16, 16),
sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
(128, 128, 256)),
fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'),
pool_mod='max')),
rpn_head=dict(
type='VoteHead',
vote_moudule_cfg=dict(
in_channels=256,
vote_per_seed=1,
gt_per_seed=3,
conv_channels=(256, 256),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
norm_feats=True,
vote_loss=dict(
type='ChamferDistance',
mode='l1',
reduction='none',
loss_dst_weight=10.0)),
vote_aggregation_cfg=dict(
num_point=256,
radius=0.3,
num_sample=16,
mlp_channels=[256, 128, 128, 128],
use_xyz=True,
normalize_xyz=True),
feat_channels=(128, 128),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='sum',
loss_weight=5.0),
center_loss=dict(
type='ChamferDistance',
mode='l2',
reduction='sum',
loss_src_weight=10.0,
loss_dst_weight=10.0),
dir_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
dir_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
size_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
size_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
semantic_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)),
roi_head=dict(
type='H3DRoIHead',
primitive_list=[primitive_z_cfg, primitive_xy_cfg, primitive_line_cfg],
bbox_head=dict(
type='H3DBboxHead',
gt_per_seed=3,
num_proposal=256,
proposal_module_cfg=proposal_module_cfg,
feat_channels=(128, 128),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='sum',
loss_weight=5.0),
center_loss=dict(
type='ChamferDistance',
mode='l2',
reduction='sum',
loss_src_weight=10.0,
loss_dst_weight=10.0),
dir_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
dir_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
size_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
size_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
semantic_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1))))
# model training and testing settings
train_cfg = dict(
rpn=dict(pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'),
rpn_proposal=dict(use_nms=False),
rcnn=dict(pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'))
test_cfg = dict(
rpn=dict(
sample_mod='seed',
nms_thr=0.25,
score_thr=0.05,
per_class_proposal=True,
use_nms=False),
rcnn=dict(
sample_mod='seed',
nms_thr=0.25,
score_thr=0.05,
per_class_proposal=True))
# H3DNet: 3D Object Detection Using Hybrid Geometric Primitives
## Introduction
We implement H3DNet and provide the results and checkpoints on the ScanNet dataset.
```
@inproceedings{zhang2020h3dnet,
author = {Zhang, Zaiwei and Sun, Bo and Yang, Haitao and Huang, Qixing},
title = {H3DNet: 3D Object Detection Using Hybrid Geometric Primitives},
booktitle = {Proceedings of the European Conference on Computer Vision},
year = {2020}
}
```
## Results
### ScanNet
| Backbone | Lr schd | Mem (GB) | Inf time (fps) | AP@0.25 |AP@0.5| Download |
| :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: |
| [MultiBackbone](./h3dnet_scannet-3d-18class.py) | 3x |7.9||66.43|48.01|[model](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/votenet/votenet_8x8_scannet-3d-18class/votenet_8x8_scannet-3d-18class_20200620_230238-2cea9c3a.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/votenet/votenet_8x8_scannet-3d-18class/votenet_8x8_scannet-3d-18class_20200620_230238.log.json)|
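As a quick sanity check, the config below can be built directly into an `H3DNet` instance; a minimal sketch, assuming the config lives at `configs/h3dnet/h3dnet_scannet-3d-18class.py` and the standard mmcv/mmdet3d build APIs:
```
from mmcv import Config
from mmdet3d.models import build_detector

# assumed config path; adjust to where the file below actually lives
cfg = Config.fromfile('configs/h3dnet/h3dnet_scannet-3d-18class.py')
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
print(type(model).__name__)  # H3DNet
```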
_base_ = [
'../_base_/datasets/scannet-3d-18class.py', '../_base_/models/h3dnet.py',
'../_base_/schedules/schedule_3x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
rpn_head=dict(
num_classes=18,
bbox_coder=dict(
type='PartialBinBasedBBoxCoder',
num_sizes=18,
num_dir_bins=24,
with_rot=False,
mean_sizes=[[0.76966727, 0.8116021, 0.92573744],
[1.876858, 1.8425595, 1.1931566],
[0.61328, 0.6148609, 0.7182701],
[1.3955007, 1.5121545, 0.83443564],
[0.97949594, 1.0675149, 0.6329687],
[0.531663, 0.5955577, 1.7500148],
[0.9624706, 0.72462326, 1.1481868],
[0.83221924, 1.0490936, 1.6875663],
[0.21132214, 0.4206159, 0.5372846],
[1.4440073, 1.8970833, 0.26985747],
[1.0294262, 1.4040797, 0.87554324],
[1.3766412, 0.65521795, 1.6813129],
[0.6650819, 0.71111923, 1.298853],
[0.41999173, 0.37906948, 1.7513971],
[0.59359556, 0.5912492, 0.73919016],
[0.50867593, 0.50656086, 0.30136237],
[1.1511526, 1.0546296, 0.49706793],
[0.47535285, 0.49249494, 0.5802117]])),
roi_head=dict(
bbox_head=dict(
num_classes=18,
bbox_coder=dict(
type='PartialBinBasedBBoxCoder',
num_sizes=18,
num_dir_bins=24,
with_rot=False,
mean_sizes=[[0.76966727, 0.8116021, 0.92573744],
[1.876858, 1.8425595, 1.1931566],
[0.61328, 0.6148609, 0.7182701],
[1.3955007, 1.5121545, 0.83443564],
[0.97949594, 1.0675149, 0.6329687],
[0.531663, 0.5955577, 1.7500148],
[0.9624706, 0.72462326, 1.1481868],
[0.83221924, 1.0490936, 1.6875663],
[0.21132214, 0.4206159, 0.5372846],
[1.4440073, 1.8970833, 0.26985747],
[1.0294262, 1.4040797, 0.87554324],
[1.3766412, 0.65521795, 1.6813129],
[0.6650819, 0.71111923, 1.298853],
[0.41999173, 0.37906948, 1.7513971],
[0.59359556, 0.5912492, 0.73919016],
[0.50867593, 0.50656086, 0.30136237],
[1.1511526, 1.0546296, 0.49706793],
[0.47535285, 0.49249494, 0.5802117]]))))
data = dict(samples_per_gpu=3, workers_per_gpu=2)
# optimizer
# yapf:disable
log_config = dict(
interval=30,
hooks=[
dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')
])
# yapf:enable
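The `_base_` entries above supply the dataset, model, schedule, and runtime defaults; this file only overrides the class-dependent fields and the data settings. A sketch of inspecting the merged result, assuming mmcv's recursive config inheritance:
```
from mmcv import Config

cfg = Config.fromfile('configs/h3dnet/h3dnet_scannet-3d-18class.py')
# overrides from this file are merged into the _base_ model definition
assert cfg.model.rpn_head.num_classes == 18
assert cfg.model.roi_head.bbox_head.bbox_coder.num_sizes == 18
assert cfg.data.samples_per_gpu == 3
```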
@@ -55,7 +55,7 @@ class PartialBinBasedBBoxCoder(BaseBBoxCoder):
         return (center_target, size_class_target, size_res_target,
                 dir_class_target, dir_res_target)

-    def decode(self, bbox_out):
+    def decode(self, bbox_out, suffix=''):
         """Decode predicted parts to bbox3d.

         Args:
@@ -66,17 +66,18 @@ class PartialBinBasedBBoxCoder(BaseBBoxCoder):
             - dir_res: predicted bbox direction residual.
             - size_class: predicted bbox size class.
             - size_res: predicted bbox size residual.
+            suffix (str): Decode predictions with specific suffix.

         Returns:
             torch.Tensor: Decoded bbox3d with shape (batch, n, 7).
         """
-        center = bbox_out['center']
+        center = bbox_out['center' + suffix]
         batch_size, num_proposal = center.shape[:2]

         # decode heading angle
         if self.with_rot:
-            dir_class = torch.argmax(bbox_out['dir_class'], -1)
-            dir_res = torch.gather(bbox_out['dir_res'], 2,
+            dir_class = torch.argmax(bbox_out['dir_class' + suffix], -1)
+            dir_res = torch.gather(bbox_out['dir_res' + suffix], 2,
                                    dir_class.unsqueeze(-1))
             dir_res.squeeze_(2)
             dir_angle = self.class2angle(dir_class, dir_res).reshape(
@@ -85,8 +86,9 @@ class PartialBinBasedBBoxCoder(BaseBBoxCoder):
             dir_angle = center.new_zeros(batch_size, num_proposal, 1)

         # decode bbox size
-        size_class = torch.argmax(bbox_out['size_class'], -1, keepdim=True)
-        size_res = torch.gather(bbox_out['size_res'], 2,
+        size_class = torch.argmax(
+            bbox_out['size_class' + suffix], -1, keepdim=True)
+        size_res = torch.gather(bbox_out['size_res' + suffix], 2,
                                 size_class.unsqueeze(-1).repeat(1, 1, 1, 3))
         mean_sizes = center.new_tensor(self.mean_sizes)
         size_base = torch.index_select(mean_sizes, 0, size_class.reshape(-1))
...
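The new `suffix` argument only changes which keys `decode` reads from the prediction dict, so the same coder can decode both the initial proposals (`suffix=''`) and the refined ones (`suffix='_optimized'`, as used by the H3D bbox head below). A minimal sketch of the key selection, with hypothetical tensors:
```
import torch

# hypothetical refined predictions stored under suffixed keys
bbox_out = {
    'center_optimized': torch.rand(2, 256, 3),
    'dir_class_optimized': torch.rand(2, 256, 24),
}
suffix = '_optimized'
center = bbox_out['center' + suffix]  # selects the refined branch
assert center.shape == (2, 256, 3)
```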
@@ -251,3 +251,52 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
         box_idxs_of_pts = points_in_boxes_batch(points_lidar, boxes_lidar)
         return box_idxs_of_pts.squeeze(0)
def get_surface_line_center(self):
"""Compute surface and line center of bounding boxes.
Returns:
torch.Tensor: Surface and line center of bounding boxes.
"""
obj_size = self.dims
center = self.gravity_center
batch_size = center.shape[0]
rot_sin = torch.sin(-self.yaw)
rot_cos = torch.cos(-self.yaw)
rot_mat_T = self.yaw.new_zeros(tuple(list(self.yaw.shape) + [3, 3]))
rot_mat_T[..., 0, 0] = rot_cos
rot_mat_T[..., 0, 1] = -rot_sin
rot_mat_T[..., 1, 0] = rot_sin
rot_mat_T[..., 1, 1] = rot_cos
rot_mat_T[..., 2, 2] = 1
# Get the object surface center
offset = obj_size.new_tensor([[0, 0, 1], [0, 0, -1], [0, 1, 0],
[0, -1, 0], [1, 0, 0], [-1, 0, 0]])
offset = offset.view(1, 6, 3) / 2
surface_3d = (offset * obj_size.view(batch_size, 1, 3).repeat(
1, 6, 1)).transpose(0, 1).reshape(-1, 3)
# Get the object line center
offset = obj_size.new_tensor([[1, 0, 1], [-1, 0, 1], [0, 1, 1],
[0, -1, 1], [1, 0, -1], [-1, 0, -1],
[0, 1, -1], [0, -1, -1], [1, 1, 0],
[1, -1, 0], [-1, 1, 0], [-1, -1, 0]])
offset = offset.view(1, 12, 3) / 2
line_3d = (offset *
obj_size.view(batch_size, 1, 3).repeat(1, 12, 1)).transpose(
0, 1).reshape(-1, 3)
surface_rot = rot_mat_T.repeat(6, 1, 1)
surface_3d = torch.matmul(
surface_3d.unsqueeze(-2), surface_rot.transpose(2, 1)).squeeze(-2)
surface_center = center.repeat(6, 1) + surface_3d
line_rot = rot_mat_T.repeat(12, 1, 1)
line_3d = torch.matmul(
line_3d.unsqueeze(-2), line_rot.transpose(2, 1)).squeeze(-2)
line_center = center.repeat(12, 1) + line_3d
return surface_center, line_center
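Each box contributes 6 face centers and 12 edge centers, so for `N` boxes the method returns tensors of shape `(6 * N, 3)` and `(12 * N, 3)`. A quick shape check as a sketch:
```
import torch
from mmdet3d.core.bbox import DepthInstance3DBoxes

# two boxes in (x, y, z, dx, dy, dz, yaw) format
boxes = DepthInstance3DBoxes(
    torch.tensor([[0., 0., 0., 1., 1., 1., 0.],
                  [2., 2., 0., 2., 1., 1., 0.]]))
surface_center, line_center = boxes.get_surface_line_center()
assert surface_center.shape == (6 * 2, 3)  # 6 face centers per box
assert line_center.shape == (12 * 2, 3)  # 12 edge centers per box
```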
@@ -164,6 +164,7 @@ class VoteHead(nn.Module):
                                                  sample_indices)
         aggregated_points, features, aggregated_indices = vote_aggregation_ret
         results['aggregated_points'] = aggregated_points
+        results['aggregated_features'] = features
         results['aggregated_indices'] = aggregated_indices

         # 3. predict bbox and score
@@ -183,7 +184,8 @@ class VoteHead(nn.Module):
              pts_semantic_mask=None,
              pts_instance_mask=None,
              img_metas=None,
-             gt_bboxes_ignore=None):
+             gt_bboxes_ignore=None,
+             ret_target=False):
         """Compute loss.

         Args:
@@ -199,6 +201,7 @@ class VoteHead(nn.Module):
             img_metas (list[dict]): Contain pcd and img's meta info.
             gt_bboxes_ignore (None | list[torch.Tensor]): Specify
                 which bounding boxes to ignore.
+            ret_target (bool): Whether to return targets.

         Returns:
             dict: Losses of VoteNet.
@@ -283,6 +286,10 @@ class VoteHead(nn.Module):
             dir_res_loss=dir_res_loss,
             size_class_loss=size_class_loss,
             size_res_loss=size_res_loss)
+
+        if ret_target:
+            losses['targets'] = targets
+
         return losses

     def get_targets(self,
@@ -494,7 +501,12 @@ class VoteHead(nn.Module):
             dir_class_targets, dir_res_targets, center_targets,
             mask_targets.long(), objectness_targets, objectness_masks)

-    def get_bboxes(self, points, bbox_preds, input_metas, rescale=False):
+    def get_bboxes(self,
+                   points,
+                   bbox_preds,
+                   input_metas,
+                   rescale=False,
+                   use_nms=True):
         """Generate bboxes from vote head predictions.

         Args:
@@ -502,6 +514,8 @@ class VoteHead(nn.Module):
             bbox_preds (dict): Predictions from vote head.
             input_metas (list[dict]): Point cloud and image's meta info.
             rescale (bool): Whether to rescale bboxes.
+            use_nms (bool): Whether to apply NMS; NMS post-processing is
+                skipped when the vote head serves as the RPN stage.

         Returns:
             list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
@@ -511,11 +525,13 @@ class VoteHead(nn.Module):
         sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)
         bbox3d = self.bbox_coder.decode(bbox_preds)

+        if use_nms:
            batch_size = bbox3d.shape[0]
            results = list()
            for b in range(batch_size):
-                bbox_selected, score_selected, labels = self.multiclass_nms_single(
-                    obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3],
-                    input_metas[b])
+                bbox_selected, score_selected, labels = \
+                    self.multiclass_nms_single(obj_scores[b], sem_scores[b],
+                                               bbox3d[b], points[b, ..., :3],
+                                               input_metas[b])
                bbox = input_metas[b]['box_type_3d'](
                    bbox_selected,
@@ -524,6 +540,8 @@ class VoteHead(nn.Module):
                results.append((bbox, score_selected, labels))

            return results
+        else:
+            return bbox3d

     def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
                               input_meta):
...
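This `use_nms` switch is what `rpn_proposal=dict(use_nms=False)` in the config above toggles: with NMS disabled, `get_bboxes` returns the decoded `(batch, num_proposal, 7)` tensor directly, and H3DNet stores it as `feats_dict['proposal_list']` for the RoI head, which re-wraps it into boxes. A sketch of that hand-off, using a random stand-in for the decoder output:
```
import torch
from mmdet3d.core.bbox import DepthInstance3DBoxes

batch_size, num_proposal = 2, 256
# stand-in for bbox_coder.decode(bbox_preds) when use_nms=False
proposal_list = torch.rand(batch_size, num_proposal, 7)

# the RoI head later re-wraps the raw proposals, as in H3DBboxHead.forward
rpn_proposals_bbox = DepthInstance3DBoxes(
    proposal_list.reshape(-1, 7),
    box_dim=7,
    with_yaw=False,
    origin=(0.5, 0.5, 0.5))
assert rpn_proposals_bbox.tensor.shape[0] == batch_size * num_proposal
```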
 from .base import Base3DDetector
 from .dynamic_voxelnet import DynamicVoxelNet
+from .h3dnet import H3DNet
 from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN
 from .mvx_two_stage import MVXTwoStageDetector
 from .parta2 import PartA2
@@ -8,5 +9,5 @@ from .voxelnet import VoxelNet
 __all__ = [
     'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector',
-    'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet'
+    'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet'
 ]
import torch
from mmdet3d.core import merge_aug_bboxes_3d
from mmdet.models import DETECTORS
from .two_stage import TwoStage3DDetector
@DETECTORS.register_module()
class H3DNet(TwoStage3DDetector):
r"""H3DNet model.
Please refer to the `paper <https://arxiv.org/abs/2006.05682>`_
for details.
"""
def __init__(self,
backbone,
neck=None,
rpn_head=None,
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(H3DNet, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
def forward_train(self,
points,
img_metas,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
gt_bboxes_ignore=None):
"""Forward of training.
Args:
points (list[torch.Tensor]): Points of each batch.
img_metas (list): Image metas.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
pts_semantic_mask (None | list[torch.Tensor]): point-wise semantic
label of each batch.
pts_instance_mask (None | list[torch.Tensor]): point-wise instance
label of each batch.
gt_bboxes_ignore (None | list[torch.Tensor]): Specify
which bounding boxes to ignore.
Returns:
dict: Losses.
"""
points_cat = torch.stack(points)
feats_dict = self.extract_feat(points_cat)
feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
feats_dict['fp_features'] = [feats_dict['hd_feature']]
feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
losses = dict()
if self.with_rpn:
rpn_outs = self.rpn_head(feats_dict, self.train_cfg.rpn.sample_mod)
feats_dict.update(rpn_outs)
rpn_loss_inputs = (points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask, img_metas)
rpn_losses = self.rpn_head.loss(
rpn_outs,
*rpn_loss_inputs,
gt_bboxes_ignore=gt_bboxes_ignore,
ret_target=True)
feats_dict['targets'] = rpn_losses.pop('targets')
losses.update(rpn_losses)
# Generate rpn proposals
proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn)
proposal_inputs = (points, rpn_outs, img_metas)
proposal_list = self.rpn_head.get_bboxes(
*proposal_inputs, use_nms=proposal_cfg.use_nms)
feats_dict['proposal_list'] = proposal_list
else:
raise NotImplementedError
roi_losses = self.roi_head.forward_train(feats_dict, img_metas, points,
gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask,
pts_instance_mask,
gt_bboxes_ignore)
losses.update(roi_losses)
return losses
def simple_test(self, points, img_metas, imgs=None, rescale=False):
"""Forward of testing.
Args:
points (list[torch.Tensor]): Points of each sample.
img_metas (list): Image metas.
rescale (bool): Whether to rescale results.
Returns:
list: Predicted 3d boxes.
"""
points_cat = torch.stack(points)
feats_dict = self.extract_feat(points_cat)
feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
feats_dict['fp_features'] = [feats_dict['hd_feature']]
feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
if self.with_rpn:
proposal_cfg = self.test_cfg.rpn
rpn_outs = self.rpn_head(feats_dict, proposal_cfg.sample_mod)
feats_dict.update(rpn_outs)
# Generate rpn proposals
proposal_list = self.rpn_head.get_bboxes(
points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms)
feats_dict['proposal_list'] = proposal_list
else:
raise NotImplementedError
return self.roi_head.simple_test(
feats_dict, img_metas, points_cat, rescale=rescale)
def aug_test(self, points, img_metas, imgs=None, rescale=False):
"""Test with augmentation."""
points_cat = [torch.stack(pts) for pts in points]
feats_dict = self.extract_feats(points_cat, img_metas)
for feat_dict in feats_dict:
feat_dict['fp_xyz'] = [feat_dict['fp_xyz_net0'][-1]]
feat_dict['fp_features'] = [feat_dict['hd_feature']]
feat_dict['fp_indices'] = [feat_dict['fp_indices_net0'][-1]]
# only support aug_test for one sample
aug_bboxes = []
for feat_dict, pts_cat, img_meta in zip(feats_dict, points_cat,
img_metas):
if self.with_rpn:
proposal_cfg = self.test_cfg.rpn
rpn_outs = self.rpn_head(feat_dict, proposal_cfg.sample_mod)
feat_dict.update(rpn_outs)
# Generate rpn proposals
proposal_list = self.rpn_head.get_bboxes(
pts_cat, rpn_outs, img_meta, use_nms=proposal_cfg.use_nms)
feat_dict['proposal_list'] = proposal_list
else:
raise NotImplementedError
bbox_results = self.roi_head.simple_test(
feat_dict, img_meta, pts_cat, rescale=rescale)
aug_bboxes.append(bbox_results)
# after merging, the boxes are mapped back to the original frame
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
self.roi_head.test_cfg)
return merged_bboxes
def extract_feats(self, points, img_metas):
"""Extract features of multiple samples."""
return [
self.extract_feat(pts, img_meta)
for pts, img_meta in zip(points, img_metas)
]
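`forward_train`, `simple_test` and `aug_test` all remap the `MultiBackbone` outputs so the downstream vote head sees the single-backbone key layout. A sketch of that remapping, with hypothetical tensor shapes for the `net0` stream and the fused `hd_feature`:
```
import torch

# hypothetical MultiBackbone outputs for a batch of 2 point clouds
feats_dict = {
    'fp_xyz_net0': [torch.rand(2, 1024, 3)],
    'fp_indices_net0': [torch.randint(0, 40000, (2, 1024))],
    'hd_feature': torch.rand(2, 256, 1024),
}
# remap to the keys VoteHead expects from a single backbone
feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
feats_dict['fp_features'] = [feats_dict['hd_feature']]
feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
```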
 from .base_3droi_head import Base3DRoIHead
 from .bbox_heads import PartA2BboxHead
-from .mask_heads import PointwiseSemanticHead
+from .h3d_roi_head import H3DRoIHead
+from .mask_heads import PointwiseSemanticHead, PrimitiveHead
 from .part_aggregation_roi_head import PartAggregationROIHead
 from .roi_extractors import Single3DRoIAwareExtractor, SingleRoIExtractor

 __all__ = [
     'Base3DRoIHead', 'PartAggregationROIHead', 'PointwiseSemanticHead',
-    'Single3DRoIAwareExtractor', 'PartA2BboxHead', 'SingleRoIExtractor'
+    'Single3DRoIAwareExtractor', 'PartA2BboxHead', 'SingleRoIExtractor',
+    'H3DRoIHead', 'PrimitiveHead'
 ]
@@ -2,9 +2,11 @@ from mmdet.models.roi_heads.bbox_heads import (BBoxHead, ConvFCBBoxHead,
                                                DoubleConvFCBBoxHead,
                                                Shared2FCBBoxHead,
                                                Shared4Conv1FCBBoxHead)
+from .h3d_bbox_head import H3DBboxHead
 from .parta2_bbox_head import PartA2BboxHead

 __all__ = [
     'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead',
-    'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'PartA2BboxHead'
+    'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'H3DBboxHead',
+    'PartA2BboxHead'
 ]
import torch
from mmcv.cnn import ConvModule
from torch import nn as nn
from torch.nn import functional as F
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.ops import PointSAModule
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
@HEADS.register_module()
class H3DBboxHead(nn.Module):
r"""Bbox head of `H3DNet <https://arxiv.org/abs/2006.05682>`_.
Args:
num_classes (int): The number of classes.
suface_matching_cfg (dict): Config for surface primitive matching.
line_matching_cfg (dict): Config for line primitive matching.
bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
decoding boxes.
train_cfg (dict): Config for training.
test_cfg (dict): Config for testing.
gt_per_seed (int): Number of ground truth votes generated
from each seed point.
num_proposal (int): Number of proposal votes generated.
feat_channels (tuple[int]): Convolution channels of
prediction layer.
primitive_feat_refine_streams (int): The number of mlps to
refine primitive feature.
primitive_refine_channels (tuple[int]): Convolution channels of
prediction layer.
upper_thresh (float): Threshold for line matching.
surface_thresh (float): Threshold for surface matching.
line_thresh (float): Threshold for line matching.
conv_cfg (dict): Config of convolution in prediction layer.
norm_cfg (dict): Config of BN in prediction layer.
objectness_loss (dict): Config of objectness loss.
center_loss (dict): Config of center loss.
dir_class_loss (dict): Config of direction classification loss.
dir_res_loss (dict): Config of direction residual regression loss.
size_class_loss (dict): Config of size classification loss.
size_res_loss (dict): Config of size residual regression loss.
semantic_loss (dict): Config of point-wise semantic segmentation loss.
cues_objectness_loss (dict): Config of cues objectness loss.
cues_semantic_loss (dict): Config of cues semantic loss.
proposal_objectness_loss (dict): Config of proposal objectness
loss.
primitive_center_loss (dict): Config of primitive center regression
loss.
"""
def __init__(self,
num_classes,
suface_matching_cfg,
line_matching_cfg,
bbox_coder,
train_cfg=None,
test_cfg=None,
proposal_module_cfg=None,
gt_per_seed=1,
num_proposal=256,
feat_channels=(128, 128),
primitive_feat_refine_streams=2,
primitive_refine_channels=[128, 128, 128],
upper_thresh=100.0,
surface_thresh=0.5,
line_thresh=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=None,
center_loss=None,
dir_class_loss=None,
dir_res_loss=None,
size_class_loss=None,
size_res_loss=None,
semantic_loss=None,
cues_objectness_loss=None,
cues_semantic_loss=None,
proposal_objectness_loss=None,
primitive_center_loss=None):
super(H3DBboxHead, self).__init__()
self.num_classes = num_classes
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.gt_per_seed = gt_per_seed
self.num_proposal = num_proposal
self.with_angle = bbox_coder['with_rot']
self.upper_thresh = upper_thresh
self.surface_thresh = surface_thresh
self.line_thresh = line_thresh
self.objectness_loss = build_loss(objectness_loss)
self.center_loss = build_loss(center_loss)
self.dir_class_loss = build_loss(dir_class_loss)
self.dir_res_loss = build_loss(dir_res_loss)
self.size_class_loss = build_loss(size_class_loss)
self.size_res_loss = build_loss(size_res_loss)
self.semantic_loss = build_loss(semantic_loss)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_sizes = self.bbox_coder.num_sizes
self.num_dir_bins = self.bbox_coder.num_dir_bins
self.cues_objectness_loss = build_loss(cues_objectness_loss)
self.cues_semantic_loss = build_loss(cues_semantic_loss)
self.proposal_objectness_loss = build_loss(proposal_objectness_loss)
self.primitive_center_loss = build_loss(primitive_center_loss)
assert suface_matching_cfg['mlp_channels'][-1] == \
line_matching_cfg['mlp_channels'][-1]
# surface center matching
self.surface_center_matcher = PointSAModule(**suface_matching_cfg)
# line center matching
self.line_center_matcher = PointSAModule(**line_matching_cfg)
# Compute the matching scores
matching_feat_dims = suface_matching_cfg['mlp_channels'][-1]
self.matching_conv = ConvModule(
matching_feat_dims,
matching_feat_dims,
1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
inplace=True)
self.matching_pred = nn.Conv1d(matching_feat_dims, 2, 1)
# Compute the semantic matching scores
self.semantic_matching_conv = ConvModule(
matching_feat_dims,
matching_feat_dims,
1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
inplace=True)
self.semantic_matching_pred = nn.Conv1d(matching_feat_dims, 2, 1)
# Surface feature aggregation
self.surface_feats_aggregation = list()
for k in range(primitive_feat_refine_streams):
self.surface_feats_aggregation.append(
ConvModule(
matching_feat_dims,
matching_feat_dims,
1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
inplace=True))
self.surface_feats_aggregation = nn.Sequential(
*self.surface_feats_aggregation)
# Line feature aggregation
self.line_feats_aggregation = list()
for k in range(primitive_feat_refine_streams):
self.line_feats_aggregation.append(
ConvModule(
matching_feat_dims,
matching_feat_dims,
1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
inplace=True))
self.line_feats_aggregation = nn.Sequential(
*self.line_feats_aggregation)
# surface center(6) + line center(12)
prev_channel = 18 * matching_feat_dims
self.bbox_pred = nn.ModuleList()
for k in range(len(primitive_refine_channels)):
self.bbox_pred.append(
ConvModule(
prev_channel,
primitive_refine_channels[k],
1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
inplace=False))
prev_channel = primitive_refine_channels[k]
# Final object detection
# Objectness scores (2), center residual (3),
# direction class + residual (num_dir_bins * 2),
# size class + residual (num_sizes * 4), semantic scores (num_classes)
conv_out_channel = (2 + 3 + bbox_coder['num_dir_bins'] * 2 +
bbox_coder['num_sizes'] * 4 + self.num_classes)
self.bbox_pred.append(nn.Conv1d(prev_channel, conv_out_channel, 1))
def init_weights(self, pretrained=None):
"""Initialize the weights in detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
pass
def forward(self, feats_dict, sample_mod):
"""Forward pass.
Args:
feats_dict (dict): Feature dict from backbone.
sample_mod (str): Sample mode for vote aggregation layer.
valid modes are "vote", "seed" and "random".
Returns:
dict: Predictions of the H3D bbox head.
"""
ret_dict = {}
aggregated_points = feats_dict['aggregated_points']
original_feature = feats_dict['aggregated_features']
batch_size = original_feature.shape[0]
object_proposal = original_feature.shape[2]
# Extract surface center, features and semantic predictions
z_center = feats_dict['pred_z_center']
xy_center = feats_dict['pred_xy_center']
z_semantic = feats_dict['sem_cls_scores_z']
xy_semantic = feats_dict['sem_cls_scores_xy']
z_feature = feats_dict['aggregated_features_z']
xy_feature = feats_dict['aggregated_features_xy']
# Extract line points and features
line_center = feats_dict['pred_line_center']
line_feature = feats_dict['aggregated_features_line']
surface_center_pred = torch.cat((z_center, xy_center), dim=1)
ret_dict['surface_center_pred'] = surface_center_pred
ret_dict['surface_sem_pred'] = torch.cat((z_semantic, xy_semantic),
dim=1)
# Extract the surface and line centers of rpn proposals
rpn_proposals = feats_dict['proposal_list']
rpn_proposals_bbox = DepthInstance3DBoxes(
rpn_proposals.reshape(-1, 7).clone(),
box_dim=rpn_proposals.shape[-1],
with_yaw=self.with_angle,
origin=(0.5, 0.5, 0.5))
obj_surface_center, obj_line_center = \
rpn_proposals_bbox.get_surface_line_center()
obj_surface_center = obj_surface_center.reshape(
batch_size, -1, 6, 3).transpose(1, 2).reshape(batch_size, -1, 3)
obj_line_center = obj_line_center.reshape(batch_size, -1, 12,
3).transpose(1, 2).reshape(
batch_size, -1, 3)
ret_dict['surface_center_object'] = obj_surface_center
ret_dict['line_center_object'] = obj_line_center
# aggregate primitive z and xy features to rpn proposals
surface_center_feature_pred = torch.cat((z_feature, xy_feature), dim=2)
surface_center_feature_pred = torch.cat(
(surface_center_feature_pred.new_zeros(
(batch_size, 6, surface_center_feature_pred.shape[2])),
surface_center_feature_pred),
dim=1)
surface_xyz, surface_features, _ = self.surface_center_matcher(
surface_center_pred,
surface_center_feature_pred,
target_xyz=obj_surface_center)
# aggregate primitive line features to rpn proposals
line_feature = torch.cat((line_feature.new_zeros(
(batch_size, 12, line_feature.shape[2])), line_feature),
dim=1)
line_xyz, line_features, _ = self.line_center_matcher(
line_center, line_feature, target_xyz=obj_line_center)
# combine the surface and line features
combine_features = torch.cat((surface_features, line_features), dim=2)
matching_features = self.matching_conv(combine_features)
matching_score = self.matching_pred(matching_features)
ret_dict['matching_score'] = matching_score.transpose(2, 1)
semantic_matching_features = self.semantic_matching_conv(
combine_features)
semantic_matching_score = self.semantic_matching_pred(
semantic_matching_features)
ret_dict['semantic_matching_score'] = \
semantic_matching_score.transpose(2, 1)
surface_features = self.surface_feats_aggregation(surface_features)
line_features = self.line_feats_aggregation(line_features)
# Combine all surface and line features
surface_features = surface_features.view(batch_size, -1,
object_proposal)
line_features = line_features.view(batch_size, -1, object_proposal)
combine_feature = torch.cat((surface_features, line_features), dim=1)
# Final bbox predictions
bbox_predictions = self.bbox_pred[0](combine_feature)
bbox_predictions += original_feature
for conv_module in self.bbox_pred[1:]:
bbox_predictions = conv_module(bbox_predictions)
refine_decode_res = self.bbox_coder.split_pred(bbox_predictions,
aggregated_points)
for key in refine_decode_res.keys():
ret_dict[key + '_optimized'] = refine_decode_res[key]
return ret_dict
def loss(self,
bbox_preds,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
img_metas=None,
rpn_targets=None,
gt_bboxes_ignore=None):
"""Compute loss.
Args:
bbox_preds (dict): Predictions from forward of h3d bbox head.
points (list[torch.Tensor]): Input points.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
bboxes of each sample.
gt_labels_3d (list[torch.Tensor]): Labels of each sample.
pts_semantic_mask (None | list[torch.Tensor]): Point-wise
semantic mask.
pts_instance_mask (None | list[torch.Tensor]): Point-wise
instance mask.
img_metas (list[dict]): Contain pcd and img's meta info.
rpn_targets (tuple): Targets generated by the rpn head.
gt_bboxes_ignore (None | list[torch.Tensor]): Specify
which bounding boxes to ignore.
Returns:
dict: Losses of H3DNet.
"""
(vote_targets, vote_target_masks, size_class_targets, size_res_targets,
dir_class_targets, dir_res_targets, center_targets, mask_targets,
valid_gt_masks, objectness_targets, objectness_weights,
box_loss_weights, valid_gt_weights) = rpn_targets
losses = {}
# calculate refined proposal loss
refined_proposal_loss = self.get_proposal_stage_loss(
bbox_preds,
size_class_targets,
size_res_targets,
dir_class_targets,
dir_res_targets,
center_targets,
mask_targets,
objectness_targets,
objectness_weights,
box_loss_weights,
valid_gt_weights,
suffix='_optimized')
for key in refined_proposal_loss.keys():
losses[key + '_optimized'] = refined_proposal_loss[key]
bbox3d_optimized = self.bbox_coder.decode(
bbox_preds, suffix='_optimized')
targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask,
bbox_preds)
(cues_objectness_label, cues_sem_label, proposal_objectness_label,
cues_mask, cues_match_mask, proposal_objectness_mask,
cues_matching_label, obj_surface_line_center) = targets
# match scores for each geometric primitive
objectness_scores = bbox_preds['matching_score']
# match scores for the semantics of primitives
objectness_scores_sem = bbox_preds['semantic_matching_score']
primitive_objectness_loss = self.cues_objectness_loss(
objectness_scores.transpose(2, 1),
cues_objectness_label,
weight=cues_mask,
avg_factor=cues_mask.sum() + 1e-6)
primitive_sem_loss = self.cues_semantic_loss(
objectness_scores_sem.transpose(2, 1),
cues_sem_label,
weight=cues_mask,
avg_factor=cues_mask.sum() + 1e-6)
objectness_scores = bbox_preds['obj_scores_optimized']
objectness_loss_refine = self.proposal_objectness_loss(
objectness_scores.transpose(2, 1), proposal_objectness_label)
primitive_matching_loss = (objectness_loss_refine *
cues_match_mask).sum() / (
cues_match_mask.sum() + 1e-6) * 0.5
primitive_sem_matching_loss = (
objectness_loss_refine * proposal_objectness_mask).sum() / (
proposal_objectness_mask.sum() + 1e-6) * 0.5
# Get the object surface center here
batch_size, object_proposal = bbox3d_optimized.shape[:2]
refined_bbox = DepthInstance3DBoxes(
bbox3d_optimized.reshape(-1, 7).clone(),
box_dim=bbox3d_optimized.shape[-1],
with_yaw=self.with_angle,
origin=(0.5, 0.5, 0.5))
pred_obj_surface_center, pred_obj_line_center = \
refined_bbox.get_surface_line_center()
pred_obj_surface_center = pred_obj_surface_center.reshape(
batch_size, -1, 6, 3).transpose(1, 2).reshape(batch_size, -1, 3)
pred_obj_line_center = pred_obj_line_center.reshape(
batch_size, -1, 12, 3).transpose(1, 2).reshape(batch_size, -1, 3)
pred_surface_line_center = torch.cat(
(pred_obj_surface_center, pred_obj_line_center), 1)
square_dist = self.primitive_center_loss(pred_surface_line_center,
obj_surface_line_center)
match_dist = torch.sqrt(square_dist.sum(dim=-1) + 1e-6)
primitive_centroid_reg_loss = torch.sum(
match_dist * cues_matching_label) / (
cues_matching_label.sum() + 1e-6)
refined_loss = dict(
primitive_objectness_loss=primitive_objectness_loss,
primitive_sem_loss=primitive_sem_loss,
primitive_matching_loss=primitive_matching_loss,
primitive_sem_matching_loss=primitive_sem_matching_loss,
primitive_centroid_reg_loss=primitive_centroid_reg_loss)
losses.update(refined_loss)
return losses
def get_bboxes(self,
points,
bbox_preds,
input_metas,
rescale=False,
suffix=''):
"""Generate bboxes from vote head predictions.
Args:
points (torch.Tensor): Input points.
bbox_preds (dict): Predictions from vote head.
input_metas (list[dict]): Point cloud and image's meta info.
rescale (bool): Whether to rescale bboxes.
suffix (str): Suffix of the prediction keys to read
(e.g. '_optimized' for the refined proposals).
Returns:
list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
"""
# decode boxes
obj_scores = F.softmax(
bbox_preds['obj_scores' + suffix], dim=-1)[..., -1]
sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)
prediction_collection = {}
prediction_collection['center'] = bbox_preds['center' + suffix]
prediction_collection['dir_class'] = bbox_preds['dir_class']
prediction_collection['dir_res'] = bbox_preds['dir_res' + suffix]
prediction_collection['size_class'] = bbox_preds['size_class']
prediction_collection['size_res'] = bbox_preds['size_res' + suffix]
bbox3d = self.bbox_coder.decode(prediction_collection)
batch_size = bbox3d.shape[0]
results = list()
for b in range(batch_size):
bbox_selected, score_selected, labels = self.multiclass_nms_single(
obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3],
input_metas[b])
bbox = input_metas[b]['box_type_3d'](
bbox_selected,
box_dim=bbox_selected.shape[-1],
with_yaw=self.bbox_coder.with_rot)
results.append((bbox, score_selected, labels))
return results
def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
input_meta):
"""Multi-class nms in single batch.
Args:
obj_scores (torch.Tensor): Objectness score of bounding boxes.
sem_scores (torch.Tensor): semantic class score of bounding boxes.
bbox (torch.Tensor): Predicted bounding boxes.
points (torch.Tensor): Input points.
input_meta (dict): Point cloud and image's meta info.
Returns:
tuple[torch.Tensor]: Bounding boxes, scores and labels.
"""
bbox = input_meta['box_type_3d'](
bbox,
box_dim=bbox.shape[-1],
with_yaw=self.bbox_coder.with_rot,
origin=(0.5, 0.5, 0.5))
box_indices = bbox.points_in_boxes(points)
corner3d = bbox.corners
minmax_box3d = corner3d.new(torch.Size((corner3d.shape[0], 6)))
minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0]
minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0]
nonempty_box_mask = box_indices.T.sum(1) > 5
bbox_classes = torch.argmax(sem_scores, -1)
nms_selected = aligned_3d_nms(minmax_box3d[nonempty_box_mask],
obj_scores[nonempty_box_mask],
bbox_classes[nonempty_box_mask],
self.test_cfg.nms_thr)
# filter empty boxes and boxes with low score
scores_mask = (obj_scores > self.test_cfg.score_thr)
nonempty_box_inds = torch.nonzero(nonempty_box_mask).flatten()
nonempty_mask = torch.zeros_like(bbox_classes).scatter(
0, nonempty_box_inds[nms_selected], 1)
selected = (nonempty_mask.bool() & scores_mask.bool())
if self.test_cfg.per_class_proposal:
bbox_selected, score_selected, labels = [], [], []
for k in range(sem_scores.shape[-1]):
bbox_selected.append(bbox[selected].tensor)
score_selected.append(obj_scores[selected] *
sem_scores[selected][:, k])
labels.append(
torch.zeros_like(bbox_classes[selected]).fill_(k))
bbox_selected = torch.cat(bbox_selected, 0)
score_selected = torch.cat(score_selected, 0)
labels = torch.cat(labels, 0)
else:
bbox_selected = bbox[selected].tensor
score_selected = obj_scores[selected]
labels = bbox_classes[selected]
return bbox_selected, score_selected, labels
def get_proposal_stage_loss(self,
bbox_preds,
size_class_targets,
size_res_targets,
dir_class_targets,
dir_res_targets,
center_targets,
mask_targets,
objectness_targets,
objectness_weights,
box_loss_weights,
valid_gt_weights,
suffix=''):
"""Compute loss for the aggregation module.
Args:
bbox_preds (dict): Predictions from forward of vote head.
size_class_targets (torch.Tensor): Ground truth \
size class of each prediction bounding box.
size_res_targets (torch.Tensor): Ground truth \
size residual of each prediction bounding box.
dir_class_targets (torch.Tensor): Ground truth \
direction class of each prediction bounding box.
dir_res_targets (torch.Tensor): Ground truth \
direction residual of each prediction bounding box.
center_targets (torch.Tensor): Ground truth center \
of each prediction bounding box.
mask_targets (torch.Tensor): Validation of each \
prediction bounding box.
objectness_targets (torch.Tensor): Ground truth \
objectness label of each prediction bounding box.
objectness_weights (torch.Tensor): Weights of objectness \
loss for each prediction bounding box.
box_loss_weights (torch.Tensor): Weights of regression \
loss for each prediction bounding box.
valid_gt_weights (torch.Tensor): Validation of each \
ground truth bounding box.
Returns:
dict: Losses of aggregation module.
"""
# calculate objectness loss
objectness_loss = self.objectness_loss(
bbox_preds['obj_scores' + suffix].transpose(2, 1),
objectness_targets,
weight=objectness_weights)
# calculate center loss
source2target_loss, target2source_loss = self.center_loss(
bbox_preds['center' + suffix],
center_targets,
src_weight=box_loss_weights,
dst_weight=valid_gt_weights)
center_loss = source2target_loss + target2source_loss
# calculate direction class loss
dir_class_loss = self.dir_class_loss(
bbox_preds['dir_class' + suffix].transpose(2, 1),
dir_class_targets,
weight=box_loss_weights)
# calculate direction residual loss
batch_size, proposal_num = size_class_targets.shape[:2]
heading_label_one_hot = dir_class_targets.new_zeros(
(batch_size, proposal_num, self.num_dir_bins))
heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
dir_res_norm = (bbox_preds['dir_res_norm' + suffix] *
heading_label_one_hot).sum(dim=-1)
dir_res_loss = self.dir_res_loss(
dir_res_norm, dir_res_targets, weight=box_loss_weights)
# calculate size class loss
size_class_loss = self.size_class_loss(
bbox_preds['size_class' + suffix].transpose(2, 1),
size_class_targets,
weight=box_loss_weights)
# calculate size residual loss
one_hot_size_targets = box_loss_weights.new_zeros(
(batch_size, proposal_num, self.num_sizes))
one_hot_size_targets.scatter_(2, size_class_targets.unsqueeze(-1), 1)
one_hot_size_targets_expand = one_hot_size_targets.unsqueeze(
-1).repeat(1, 1, 1, 3)
size_residual_norm = (bbox_preds['size_res_norm' + suffix] *
one_hot_size_targets_expand).sum(dim=2)
box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat(
1, 1, 3)
size_res_loss = self.size_res_loss(
size_residual_norm,
size_res_targets,
weight=box_loss_weights_expand)
# calculate semantic loss
semantic_loss = self.semantic_loss(
bbox_preds['sem_scores' + suffix].transpose(2, 1),
mask_targets,
weight=box_loss_weights)
losses = dict(
objectness_loss=objectness_loss,
semantic_loss=semantic_loss,
center_loss=center_loss,
dir_class_loss=dir_class_loss,
dir_res_loss=dir_res_loss,
size_class_loss=size_class_loss,
size_res_loss=size_res_loss)
return losses
def get_targets(self,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
bbox_preds=None):
"""Generate targets of proposal module.
Args:
points (list[torch.Tensor]): Points of each batch.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): Labels of each batch.
pts_semantic_mask (None | list[torch.Tensor]): Point-wise semantic
label of each batch.
pts_instance_mask (None | list[torch.Tensor]): Point-wise instance
label of each batch.
bbox_preds (torch.Tensor): Bounding box predictions of vote head.
Returns:
tuple[torch.Tensor]: Targets of proposal module.
"""
# find empty example
valid_gt_masks = list()
gt_num = list()
for index in range(len(gt_labels_3d)):
if len(gt_labels_3d[index]) == 0:
fake_box = gt_bboxes_3d[index].tensor.new_zeros(
1, gt_bboxes_3d[index].tensor.shape[-1])
gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
valid_gt_masks.append(gt_labels_3d[index].new_zeros(1))
gt_num.append(1)
else:
valid_gt_masks.append(gt_labels_3d[index].new_ones(
gt_labels_3d[index].shape))
gt_num.append(gt_labels_3d[index].shape[0])
if pts_semantic_mask is None:
pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
pts_instance_mask = [None for i in range(len(gt_labels_3d))]
aggregated_points = [
bbox_preds['aggregated_points'][i]
for i in range(len(gt_labels_3d))
]
surface_center_pred = [
bbox_preds['surface_center_pred'][i]
for i in range(len(gt_labels_3d))
]
line_center_pred = [
bbox_preds['pred_line_center'][i]
for i in range(len(gt_labels_3d))
]
surface_center_object = [
bbox_preds['surface_center_object'][i]
for i in range(len(gt_labels_3d))
]
line_center_object = [
bbox_preds['line_center_object'][i]
for i in range(len(gt_labels_3d))
]
surface_sem_pred = [
bbox_preds['surface_sem_pred'][i]
for i in range(len(gt_labels_3d))
]
line_sem_pred = [
bbox_preds['sem_cls_scores_line'][i]
for i in range(len(gt_labels_3d))
]
(cues_objectness_label, cues_sem_label, proposal_objectness_label,
cues_mask, cues_match_mask, proposal_objectness_mask,
cues_matching_label, obj_surface_line_center) = multi_apply(
self.get_targets_single, points, gt_bboxes_3d, gt_labels_3d,
pts_semantic_mask, pts_instance_mask, aggregated_points,
surface_center_pred, line_center_pred, surface_center_object,
line_center_object, surface_sem_pred, line_sem_pred)
cues_objectness_label = torch.stack(cues_objectness_label)
cues_sem_label = torch.stack(cues_sem_label)
proposal_objectness_label = torch.stack(proposal_objectness_label)
cues_mask = torch.stack(cues_mask)
cues_match_mask = torch.stack(cues_match_mask)
proposal_objectness_mask = torch.stack(proposal_objectness_mask)
cues_matching_label = torch.stack(cues_matching_label)
obj_surface_line_center = torch.stack(obj_surface_line_center)
return (cues_objectness_label, cues_sem_label,
proposal_objectness_label, cues_mask, cues_match_mask,
proposal_objectness_mask, cues_matching_label,
obj_surface_line_center)
def get_targets_single(self,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask=None,
pts_instance_mask=None,
aggregated_points=None,
pred_surface_center=None,
pred_line_center=None,
pred_obj_surface_center=None,
pred_obj_line_center=None,
pred_surface_sem=None,
pred_line_sem=None):
"""Generate targets for primitive cues for single batch.
Args:
points (torch.Tensor): Points of each batch.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth \
boxes of each batch.
gt_labels_3d (torch.Tensor): Labels of each batch.
pts_semantic_mask (None | torch.Tensor): Point-wise semantic
label of each batch.
pts_instance_mask (None | torch.Tensor): Point-wise instance
label of each batch.
aggregated_points (torch.Tensor): Aggregated points from
vote aggregation layer.
pred_surface_center (torch.Tensor): Prediction of surface center.
pred_line_center (torch.Tensor): Prediction of line center.
pred_obj_surface_center (torch.Tensor): Objectness prediction \
of surface center.
pred_obj_line_center (torch.Tensor): Objectness prediction of \
line center.
pred_surface_sem (torch.Tensor): Semantic prediction of \
surface center.
pred_line_sem (torch.Tensor): Semantic prediction of line center.
Returns:
tuple[torch.Tensor]: Targets for primitive cues.
"""
device = points.device
gt_bboxes_3d = gt_bboxes_3d.to(device)
num_proposals = aggregated_points.shape[0]
gt_center = gt_bboxes_3d.gravity_center
dist1, dist2, ind1, _ = chamfer_distance(
aggregated_points.unsqueeze(0),
gt_center.unsqueeze(0),
reduction='none')
# Set assignment
object_assignment = ind1.squeeze(0)
# Generate objectness label and mask
# objectness_label: 1 if pred object center is within
# self.train_cfg['near_threshold'] of any GT object
# objectness_mask: 0 if pred object center is in gray
# zone (DONOTCARE), 1 otherwise
euclidean_dist1 = torch.sqrt(dist1.squeeze(0) + 1e-6)
proposal_objectness_label = euclidean_dist1.new_zeros(
num_proposals, dtype=torch.long)
proposal_objectness_mask = euclidean_dist1.new_zeros(num_proposals)
gt_sem = gt_labels_3d[object_assignment]
obj_surface_center, obj_line_center = \
gt_bboxes_3d.get_surface_line_center()
obj_surface_center = obj_surface_center.reshape(-1, 6,
3).transpose(0, 1)
obj_line_center = obj_line_center.reshape(-1, 12, 3).transpose(0, 1)
obj_surface_center = obj_surface_center[:, object_assignment].reshape(
1, -1, 3)
obj_line_center = obj_line_center[:,
object_assignment].reshape(1, -1, 3)
surface_sem = torch.argmax(pred_surface_sem, dim=1).float()
line_sem = torch.argmax(pred_line_sem, dim=1).float()
dist_surface, _, surface_ind, _ = chamfer_distance(
obj_surface_center,
pred_surface_center.unsqueeze(0),
reduction='none')
dist_line, _, line_ind, _ = chamfer_distance(
obj_line_center, pred_line_center.unsqueeze(0), reduction='none')
surface_sel = pred_surface_center[surface_ind.squeeze(0)]
line_sel = pred_line_center[line_ind.squeeze(0)]
surface_sel_sem = surface_sem[surface_ind.squeeze(0)]
line_sel_sem = line_sem[line_ind.squeeze(0)]
surface_sel_sem_gt = gt_sem.repeat(6).float()
line_sel_sem_gt = gt_sem.repeat(12).float()
euclidean_dist_surface = torch.sqrt(dist_surface.squeeze(0) + 1e-6)
euclidean_dist_line = torch.sqrt(dist_line.squeeze(0) + 1e-6)
objectness_label_surface = euclidean_dist_line.new_zeros(
num_proposals * 6, dtype=torch.long)
objectness_mask_surface = euclidean_dist_line.new_zeros(num_proposals *
6)
objectness_label_line = euclidean_dist_line.new_zeros(
num_proposals * 12, dtype=torch.long)
objectness_mask_line = euclidean_dist_line.new_zeros(num_proposals *
12)
objectness_label_surface_sem = euclidean_dist_line.new_zeros(
num_proposals * 6, dtype=torch.long)
objectness_label_line_sem = euclidean_dist_line.new_zeros(
num_proposals * 12, dtype=torch.long)
euclidean_dist_obj_surface = torch.sqrt((
(pred_obj_surface_center - surface_sel)**2).sum(dim=-1) + 1e-6)
euclidean_dist_obj_line = torch.sqrt(
torch.sum((pred_obj_line_center - line_sel)**2, dim=-1) + 1e-6)
# Objectness score just with centers
proposal_objectness_label[
euclidean_dist1 < self.train_cfg['near_threshold']] = 1
proposal_objectness_mask[
euclidean_dist1 < self.train_cfg['near_threshold']] = 1
proposal_objectness_mask[
euclidean_dist1 > self.train_cfg['far_threshold']] = 1
objectness_label_surface[
(euclidean_dist_obj_surface <
self.train_cfg['label_surface_threshold']) *
(euclidean_dist_surface <
self.train_cfg['mask_surface_threshold'])] = 1
objectness_label_surface_sem[
(euclidean_dist_obj_surface <
self.train_cfg['label_surface_threshold']) *
(euclidean_dist_surface < self.train_cfg['mask_surface_threshold'])
* (surface_sel_sem == surface_sel_sem_gt)] = 1
objectness_label_line[
(euclidean_dist_obj_line < self.train_cfg['label_line_threshold'])
*
(euclidean_dist_line < self.train_cfg['mask_line_threshold'])] = 1
objectness_label_line_sem[
(euclidean_dist_obj_line < self.train_cfg['label_line_threshold'])
* (euclidean_dist_line < self.train_cfg['mask_line_threshold']) *
(line_sel_sem == line_sel_sem_gt)] = 1
objectness_label_surface_obj = proposal_objectness_label.repeat(6)
objectness_mask_surface_obj = proposal_objectness_mask.repeat(6)
objectness_label_line_obj = proposal_objectness_label.repeat(12)
objectness_mask_line_obj = proposal_objectness_mask.repeat(12)
objectness_mask_surface = objectness_mask_surface_obj
objectness_mask_line = objectness_mask_line_obj
cues_objectness_label = torch.cat(
(objectness_label_surface, objectness_label_line), 0)
cues_sem_label = torch.cat(
(objectness_label_surface_sem, objectness_label_line_sem), 0)
cues_mask = torch.cat((objectness_mask_surface, objectness_mask_line),
0)
objectness_label_surface *= objectness_label_surface_obj
objectness_label_line *= objectness_label_line_obj
cues_matching_label = torch.cat(
(objectness_label_surface, objectness_label_line), 0)
objectness_label_surface_sem *= objectness_label_surface_obj
objectness_label_line_sem *= objectness_label_line_obj
# A proposal is considered matched if at least one of its 18 cues
# (6 surfaces + 12 lines) is positive.
cues_match_mask = (torch.sum(
    cues_objectness_label.view(18, num_proposals), dim=0) >= 1).float()
obj_surface_line_center = torch.cat(
(obj_surface_center, obj_line_center), 1).squeeze(0)
return (cues_objectness_label, cues_sem_label,
proposal_objectness_label, cues_mask, cues_match_mask,
proposal_objectness_mask, cues_matching_label,
obj_surface_line_center)
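The target assignment above is essentially a nearest-neighbor match under
Chamfer distance followed by thresholding. A minimal self-contained sketch of
that matching step (plain PyTorch with an illustrative 0.3 threshold, not the
head's actual `chamfer_distance` helper or its `train_cfg` values):
import torch

def nearest_match(gt_centers, pred_centers):
    # Pairwise Euclidean distances: (N, M) for N GT and M predicted centers.
    dist = torch.cdist(gt_centers, pred_centers)
    # For every GT center, the distance to and index of its nearest
    # predicted center.
    min_dist, min_idx = dist.min(dim=1)
    return min_dist, min_idx

# Toy GT surface centers vs. predicted primitive centers.
gt = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
pred = torch.tensor([[0.1, 0.0, 0.0], [2.0, 2.0, 2.0], [0.9, 1.1, 1.0]])
min_dist, min_idx = nearest_match(gt, pred)
# A cue is positive when the matched pair is closer than the threshold.
labels = (min_dist < 0.3).long()  # -> tensor([1, 1])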
from mmdet3d.core.bbox import bbox3d2result
from mmdet.models import HEADS
from ..builder import build_head
from .base_3droi_head import Base3DRoIHead
@HEADS.register_module()
class H3DRoIHead(Base3DRoIHead):
"""H3D roi head for H3DNet.
Args:
primitive_list (List): Configs of primitive heads.
bbox_head (ConfigDict): Config of bbox_head.
train_cfg (ConfigDict): Training config.
test_cfg (ConfigDict): Testing config.
"""
def __init__(self,
primitive_list,
bbox_head=None,
train_cfg=None,
test_cfg=None):
super(H3DRoIHead, self).__init__(
bbox_head=bbox_head, train_cfg=train_cfg, test_cfg=test_cfg)
# Primitive module
assert len(primitive_list) == 3
self.primitive_z = build_head(primitive_list[0])
self.primitive_xy = build_head(primitive_list[1])
self.primitive_line = build_head(primitive_list[2])
def init_weights(self, pretrained):
"""Initialize weights, skip since ``H3DROIHead`` does not need to
initialize weights."""
pass
def init_mask_head(self):
"""Initialize mask head, skip since ``H3DROIHead`` does not have
one."""
pass
def init_bbox_head(self, bbox_head):
"""Initialize box head."""
bbox_head['train_cfg'] = self.train_cfg
bbox_head['test_cfg'] = self.test_cfg
self.bbox_head = build_head(bbox_head)
def init_assigner_sampler(self):
"""Initialize assigner and sampler."""
pass
def forward_train(self,
feats_dict,
img_metas,
points,
gt_bboxes_3d,
gt_labels_3d,
pts_semantic_mask,
pts_instance_mask,
gt_bboxes_ignore=None):
"""Training forward function of PartAggregationROIHead.
Args:
feats_dict (dict): Contains features from the first stage.
img_metas (list[dict]): Contain pcd and img's meta info.
points (list[torch.Tensor]): Input points.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
bboxes of each sample.
gt_labels_3d (list[torch.Tensor]): Labels of each sample.
pts_semantic_mask (None | list[torch.Tensor]): Point-wise
semantic mask.
pts_instance_mask (None | list[torch.Tensor]): Point-wise
instance mask.
gt_bboxes_ignore (None | list[torch.Tensor]): Specify
    which bounding boxes to ignore.
Returns:
dict: losses from each head.
"""
losses = dict()
sample_mod = self.train_cfg.sample_mod
assert sample_mod in ['vote', 'seed', 'random']
result_z = self.primitive_z(feats_dict, sample_mod)
feats_dict.update(result_z)
result_xy = self.primitive_xy(feats_dict, sample_mod)
feats_dict.update(result_xy)
result_line = self.primitive_line(feats_dict, sample_mod)
feats_dict.update(result_line)
primitive_loss_inputs = (feats_dict, points, gt_bboxes_3d,
gt_labels_3d, pts_semantic_mask,
pts_instance_mask, img_metas,
gt_bboxes_ignore)
loss_z = self.primitive_z.loss(*primitive_loss_inputs)
losses.update(loss_z)
loss_xy = self.primitive_xy.loss(*primitive_loss_inputs)
losses.update(loss_xy)
loss_line = self.primitive_line.loss(*primitive_loss_inputs)
losses.update(loss_line)
targets = feats_dict.pop('targets')
bbox_results = self.bbox_head(feats_dict, sample_mod)
feats_dict.update(bbox_results)
bbox_loss = self.bbox_head.loss(feats_dict, points, gt_bboxes_3d,
gt_labels_3d, pts_semantic_mask,
pts_instance_mask, img_metas, targets,
gt_bboxes_ignore)
losses.update(bbox_loss)
return losses
def simple_test(self, feats_dict, img_metas, points, rescale=False):
"""Simple testing forward function of PartAggregationROIHead.
Note:
This function assumes that the batch size is 1
Args:
feats_dict (dict): Contains features from the first stage.
img_metas (list[dict]): Contain pcd and img's meta info.
points (torch.Tensor): Input points.
rescale (bool): Whether to rescale results.
Returns:
dict: Bbox results of one frame.
"""
sample_mod = self.test_cfg.sample_mod
assert sample_mod in ['vote', 'seed', 'random']
result_z = self.primitive_z(feats_dict, sample_mod)
feats_dict.update(result_z)
result_xy = self.primitive_xy(feats_dict, sample_mod)
feats_dict.update(result_xy)
result_line = self.primitive_line(feats_dict, sample_mod)
feats_dict.update(result_line)
bbox_preds = self.bbox_head(feats_dict, sample_mod)
feats_dict.update(bbox_preds)
bbox_list = self.bbox_head.get_bboxes(
points,
feats_dict,
img_metas,
rescale=rescale,
suffix='_optimized')
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results[0]
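For context, a hypothetical sketch of wiring this head up from a config dict.
`primitive_z_cfg`, `primitive_xy_cfg` and `primitive_line_cfg` are the configs
defined earlier in this change; `h3d_bbox_head_cfg` is a placeholder for the
H3D bbox head config, and the `sample_mod` values are illustrative:
from mmcv.utils import ConfigDict
from mmdet3d.models import build_head

h3d_roi_head_cfg = ConfigDict(
    type='H3DRoIHead',
    primitive_list=[primitive_z_cfg, primitive_xy_cfg, primitive_line_cfg],
    bbox_head=h3d_bbox_head_cfg,  # placeholder config, not defined here
    # ConfigDict (not a plain dict) so that attribute access such as
    # `self.train_cfg.sample_mod` works inside the head.
    train_cfg=ConfigDict(sample_mod='vote'),
    test_cfg=ConfigDict(sample_mod='seed'))
roi_head = build_head(h3d_roi_head_cfg)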
@@ -1135,6 +1135,54 @@ def test_depth_boxes3d():
dtype=torch.int32)
assert torch.all(box_idxs_of_pts == expected_idxs_of_pts)
# test get_surface_line_center
boxes = torch.tensor(
[[0.3294, 1.0359, 0.1171, 1.0822, 1.1247, 1.3721, 0.4916],
[-2.4630, -2.6324, -0.1616, 0.9202, 1.7896, 0.1992, 0.3185]])
boxes = DepthInstance3DBoxes(
boxes, box_dim=boxes.shape[-1], with_yaw=True, origin=(0.5, 0.5, 0.5))
surface_center, line_center = boxes.get_surface_line_center()
expected_surface_center = torch.tensor([[0.3294, 1.0359, 0.8031],
[-2.4630, -2.6324, -0.0620],
[0.3294, 1.0359, -0.5689],
[-2.4630, -2.6324, -0.2612],
[0.5949, 1.5317, 0.1171],
[-2.1828, -1.7826, -0.1616],
[0.0640, 0.5401, 0.1171],
[-2.7432, -3.4822, -0.1616],
[0.8064, 0.7805, 0.1171],
[-2.0260, -2.7765, -0.1616],
[-0.1476, 1.2913, 0.1171],
[-2.9000, -2.4883, -0.1616]])
expected_line_center = torch.tensor([[0.8064, 0.7805, 0.8031],
[-2.0260, -2.7765, -0.0620],
[-0.1476, 1.2913, 0.8031],
[-2.9000, -2.4883, -0.0620],
[0.5949, 1.5317, 0.8031],
[-2.1828, -1.7826, -0.0620],
[0.0640, 0.5401, 0.8031],
[-2.7432, -3.4822, -0.0620],
[0.8064, 0.7805, -0.5689],
[-2.0260, -2.7765, -0.2612],
[-0.1476, 1.2913, -0.5689],
[-2.9000, -2.4883, -0.2612],
[0.5949, 1.5317, -0.5689],
[-2.1828, -1.7826, -0.2612],
[0.0640, 0.5401, -0.5689],
[-2.7432, -3.4822, -0.2612],
[1.0719, 1.2762, 0.1171],
[-1.7458, -1.9267, -0.1616],
[0.5410, 0.2847, 0.1171],
[-2.3062, -3.6263, -0.1616],
[0.1178, 1.7871, 0.1171],
[-2.6198, -1.6385, -0.1616],
[-0.4131, 0.7956, 0.1171],
[-3.1802, -3.3381, -0.1616]])
assert torch.allclose(surface_center, expected_surface_center, atol=1e-04)
assert torch.allclose(line_center, expected_line_center, atol=1e-04)
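A quick sanity check on the shapes above: every box contributes 6 surface
(face) centers and 12 line (edge) centers, so two boxes give 12 and 24 rows.
A minimal sketch of the face centers for a hypothetical axis-aligned box
(yaw = 0; the boxes in the test are additionally rotated by their yaw):
import torch

center = torch.tensor([0.0, 0.0, 0.0])
dims = torch.tensor([2.0, 4.0, 6.0])
# Offset the box center by +/- half the size along each axis in turn.
offsets = torch.cat([torch.eye(3), -torch.eye(3)], dim=0)
surface_centers = center + offsets * dims / 2  # shape (6, 3)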
def test_rotation_3d_in_axis():
points = torch.tensor([[[-0.4599, -0.0471, 0.0000],
...
@@ -569,3 +569,92 @@ def test_primitive_head():
with pytest.raises(AssertionError):
primitive_head_cfg['vote_moudule_cfg']['in_channels'] = 'xyz'
build_head(primitive_head_cfg)
def test_h3d_head():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
_setup_seed(0)
h3d_head_cfg = _get_roi_head_cfg('h3dnet/h3dnet_8x8_scannet-3d-18class.py')
self = build_head(h3d_head_cfg).cuda()
# prepare roi outputs
fp_xyz = [torch.rand([1, 1024, 3], dtype=torch.float32).cuda()]
hd_features = torch.rand([1, 256, 1024], dtype=torch.float32).cuda()
fp_indices = [torch.randint(0, 128, [1, 1024]).cuda()]
aggregated_points = torch.rand([1, 256, 3], dtype=torch.float32).cuda()
aggregated_features = torch.rand([1, 128, 256], dtype=torch.float32).cuda()
rpn_proposals = torch.cat([
torch.rand([1, 256, 3], dtype=torch.float32).cuda() * 4 - 2,
torch.rand([1, 256, 3], dtype=torch.float32).cuda() * 4,
torch.zeros([1, 256, 1]).cuda()
],
dim=-1)
input_dict = dict(
fp_xyz_net0=fp_xyz,
hd_feature=hd_features,
aggregated_points=aggregated_points,
aggregated_features=aggregated_features,
seed_points=fp_xyz[0],
seed_indices=fp_indices[0],
rpn_proposals=rpn_proposals)
# prepare gt label
from mmdet3d.core.bbox import DepthInstance3DBoxes
gt_bboxes_3d = [
DepthInstance3DBoxes(torch.rand([4, 7], dtype=torch.float32).cuda()),
DepthInstance3DBoxes(torch.rand([4, 7], dtype=torch.float32).cuda())
]
gt_labels_3d = torch.randint(0, 18, [1, 4]).cuda()
gt_labels_3d = [gt_labels_3d[0]]
pts_semantic_mask = torch.randint(0, 19, [1, 1024]).cuda()
pts_semantic_mask = [pts_semantic_mask[0]]
pts_instance_mask = torch.randint(0, 4, [1, 1024]).cuda()
pts_instance_mask = [pts_instance_mask[0]]
points = torch.rand([1, 1024, 3], dtype=torch.float32).cuda()
# prepare rpn targets
vote_targets = torch.rand([1, 1024, 9], dtype=torch.float32).cuda()
vote_target_masks = torch.rand([1, 1024], dtype=torch.float32).cuda()
size_class_targets = torch.rand([1, 256],
dtype=torch.float32).cuda().long()
size_res_targets = torch.rand([1, 256, 3], dtype=torch.float32).cuda()
dir_class_targets = torch.rand([1, 256], dtype=torch.float32).cuda().long()
dir_res_targets = torch.rand([1, 256], dtype=torch.float32).cuda()
center_targets = torch.rand([1, 4, 3], dtype=torch.float32).cuda()
mask_targets = torch.rand([1, 256], dtype=torch.float32).cuda().long()
valid_gt_masks = torch.rand([1, 4], dtype=torch.float32).cuda()
objectness_targets = torch.rand([1, 256],
dtype=torch.float32).cuda().long()
objectness_weights = torch.rand([1, 256], dtype=torch.float32).cuda()
box_loss_weights = torch.rand([1, 256], dtype=torch.float32).cuda()
valid_gt_weights = torch.rand([1, 4], dtype=torch.float32).cuda()
targets = (vote_targets, vote_target_masks, size_class_targets,
size_res_targets, dir_class_targets, dir_res_targets,
center_targets, mask_targets, valid_gt_masks,
objectness_targets, objectness_weights, box_loss_weights,
valid_gt_weights)
input_dict['targets'] = targets
# train forward (`sample_mod` is read from the head's train_cfg, so it
# is not passed as an argument)
ret_dict = self.forward_train(
    input_dict,
    img_metas=None,
    points=points,
    gt_bboxes_3d=gt_bboxes_3d,
    gt_labels_3d=gt_labels_3d,
    pts_semantic_mask=pts_semantic_mask,
    pts_instance_mask=pts_instance_mask)
assert ret_dict['flag_loss_z'] >= 0
assert ret_dict['vote_loss_z'] >= 0
assert ret_dict['center_loss_z'] >= 0
assert ret_dict['size_loss_z'] >= 0
assert ret_dict['sem_loss_z'] >= 0
assert ret_dict['objectness_loss_opt'] >= 0
assert ret_dict['primitive_sem_matching_loss'] >= 0
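A hypothetical follow-up (not in this diff) that would also smoke-test the
inference path, assuming the head's test_cfg provides a `sample_mod` and that
`img_metas` may stay None as in the training call:
with torch.no_grad():
    bbox_results = self.simple_test(input_dict, img_metas=None, points=points)
assert 'boxes_3d' in bbox_results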