Unverified Commit 460f6b3b authored by encore-zhou, committed by GitHub

[feature]: support ssd_3d_head in 3DSSD (#83)

* add ssd3dhead

* fix bugs for anchorfreebboxcoder

* modify ssd 3d head

* modify ssd 3d head

* reconstruct ssd3dhead and votehead

* add unittest

* modify 3dssd config

* modify 3dssd head

* modify 3dssd head

* rename base conv bbox head

* modify vote module

* modify 3dssd config

* fix bugs for unittest

* modify test_heads.py

* fix bugs for h3d bbox head

* add 3dssd detector

* fix bugs for 3dssd config

* modify base conv bbox head

* modify base conv bbox head

* modify base conv bbox head
parent 0e2ad8df
_base_ = [
'../_base_/models/3dssd.py', '../_base_/datasets/kitti-3d-car.py',
'../_base_/default_runtime.py'
]
# dataset settings
dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
class_names = ['Car']
point_cloud_range = [0, -40, -5, 70, 40, 3]
input_modality = dict(use_lidar=True, use_camera=False)
db_sampler = dict(
data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0,
prepare=dict(filter_by_difficulty=[-1], filter_by_min_points=dict(Car=5)),
classes=class_names,
sample_groups=dict(Car=15))
file_client_args = dict(backend='disk')
# Uncomment the following if you use Ceph or another file client.
# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
# for more details.
# file_client_args = dict(
# backend='petrel', path_mapping=dict(data='s3://kitti_data/'))
train_pipeline = [
dict(
type='LoadPointsFromFile',
load_dim=4,
use_dim=4,
file_client_args=file_client_args),
dict(
type='LoadAnnotations3D',
with_bbox_3d=True,
with_label_3d=True,
file_client_args=file_client_args),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(
type='ObjectNoise',
num_try=100,
translation_std=[1.0, 1.0, 0],
global_rot_range=[0.0, 0.0],
rot_range=[-1.0471975511965976, 1.0471975511965976]),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.9, 1.1]),
dict(type='BackgroundPointsFilter', bbox_enlarge_range=(0.5, 2.0, 0.5)),
dict(type='IndoorPointSample', num_points=16384),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(
type='LoadPointsFromFile',
load_dim=4,
use_dim=4,
file_client_args=file_client_args),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='IndoorPointSample', num_points=16384),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(dataset=dict(pipeline=train_pipeline)),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
evaluation = dict(interval=2)
# model settings
model = dict(
bbox_head=dict(
num_classes=1,
bbox_coder=dict(
type='AnchorFreeBBoxCoder', num_dir_bins=12, with_rot=True)))
# optimizer
lr = 0.002 # max learning rate
optimizer = dict(type='AdamW', lr=lr, weight_decay=0)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
lr_config = dict(policy='step', warmup=None, step=[80, 120])
# runtime settings
total_epochs = 150
# yapf:disable
log_config = dict(
interval=30,
hooks=[
dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')
])
# yapf:enable
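
The pipeline lists above are plain dicts until a dataset composes them. A minimal sketch of that step (not part of this commit), assuming mmdet3d is installed so that its 3D transforms are registered in mmdet's PIPELINES registry:

import mmdet3d.datasets  # noqa: F401, registers LoadPointsFromFile etc.
from mmdet.datasets.pipelines import Compose

# Builds every transform in order; a typo in any `type` key fails fast here.
composed = Compose(train_pipeline)
print(composed)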
model = dict(
type='SSD3DNet',
backbone=dict(
type='PointNet2SAMSG',
in_channels=4,
num_points=(4096, 512, (256, 256)),
radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
((64, 64, 128), (64, 64, 128), (64, 96, 128)),
((128, 128, 256), (128, 192, 256), (128, 256, 256))),
aggregation_channels=(64, 128, 256),
fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
fps_sample_range_lists=((-1), (-1), (512, -1)),
norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.1),
sa_cfg=dict(
type='PointSAModuleMSG',
pool_mod='max',
use_xyz=True,
normalize_xyz=False)),
bbox_head=dict(
type='SSD3DHead',
in_channels=256,
vote_module_cfg=dict(
in_channels=256,
num_points=256,
gt_per_seed=1,
conv_channels=(128, ),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.1),
with_res_feat=False,
vote_xyz_range=(3.0, 3.0, 2.0)),
vote_aggregation_cfg=dict(
type='PointSAModuleMSG',
num_point=256,
radii=(4.8, 6.4),
sample_nums=(16, 32),
mlp_channels=((256, 256, 256, 512), (256, 256, 512, 1024)),
norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.1),
use_xyz=True,
normalize_xyz=False,
bias=True),
pred_layer_cfg=dict(
in_channels=1536,
shared_conv_channels=(512, 128),
cls_conv_layers=(128, ),
reg_conv_layers=(128, ),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.1),
bias=True),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.1),
objectness_loss=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
center_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=1.0),
dir_class_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
dir_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=1.0),
size_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=1.0),
corner_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=1.0),
vote_loss=dict(type='SmoothL1Loss', reduction='sum', loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
sample_mod='spec', pos_distance_thr=10.0, expand_dims_length=0.05)
test_cfg = dict(
nms_cfg=dict(type='nms', iou_thr=0.1),
sample_mod='spec',
score_thr=0.0,
per_class_proposal=True,
max_output_num=100)
# optimizer
# This schedule is mainly used by models on indoor dataset,
# e.g., VoteNet on SUNRGBD and ScanNet
lr = 0.002 # max learning rate
optimizer = dict(type='AdamW', lr=lr, weight_decay=0)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
lr_config = dict(policy='step', warmup=None, step=[80, 120])
# runtime settings
total_epochs = 150
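
For reference, a minimal sketch (not part of this commit) of how a config like the ones above is typically consumed; the config path below is hypothetical:

from mmcv import Config
from mmdet3d.models import build_detector

cfg = Config.fromfile('configs/3dssd/3dssd_kitti-3d-car.py')  # hypothetical path
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)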
@@ -5,7 +5,7 @@ primitive_z_cfg = dict(
     primitive_mode='z',
     upper_thresh=100.0,
     surface_thresh=0.5,
-    vote_moudule_cfg=dict(
+    vote_module_cfg=dict(
         in_channels=256,
         vote_per_seed=1,
         gt_per_seed=1,
@@ -63,7 +63,7 @@ primitive_xy_cfg = dict(
     primitive_mode='xy',
     upper_thresh=100.0,
     surface_thresh=0.5,
-    vote_moudule_cfg=dict(
+    vote_module_cfg=dict(
         in_channels=256,
         vote_per_seed=1,
         gt_per_seed=1,
@@ -121,7 +121,7 @@ primitive_line_cfg = dict(
     primitive_mode='line',
     upper_thresh=100.0,
     surface_thresh=0.5,
-    vote_moudule_cfg=dict(
+    vote_module_cfg=dict(
         in_channels=256,
         vote_per_seed=1,
         gt_per_seed=1,
@@ -198,7 +198,7 @@ model = dict(
                 normalize_xyz=True))),
     rpn_head=dict(
         type='VoteHead',
-        vote_moudule_cfg=dict(
+        vote_module_cfg=dict(
             in_channels=256,
             vote_per_seed=1,
             gt_per_seed=3,
@@ -219,7 +219,7 @@ model = dict(
             mlp_channels=[256, 128, 128, 128],
             use_xyz=True,
             normalize_xyz=True),
-        feat_channels=(128, 128),
+        pred_layer_cfg=dict(in_channels=128, shared_conv_channels=(128, 128)),
         conv_cfg=dict(type='Conv1d'),
         norm_cfg=dict(type='BN1d'),
         objectness_loss=dict(
...
@@ -17,7 +17,7 @@ model = dict(
             normalize_xyz=True)),
     bbox_head=dict(
         type='VoteHead',
-        vote_moudule_cfg=dict(
+        vote_module_cfg=dict(
             in_channels=256,
             vote_per_seed=1,
             gt_per_seed=3,
@@ -38,7 +38,7 @@ model = dict(
             mlp_channels=[256, 128, 128, 128],
             use_xyz=True,
             normalize_xyz=True),
-        feat_channels=(128, 128),
+        pred_layer_cfg=dict(in_channels=128, shared_conv_channels=(128, 128)),
        conv_cfg=dict(type='Conv1d'),
        norm_cfg=dict(type='BN1d'),
        objectness_loss=dict(
...
 from mmdet.core.bbox import build_bbox_coder
+from .anchor_free_bbox_coder import AnchorFreeBBoxCoder
 from .delta_xyzwhlr_bbox_coder import DeltaXYZWLHRBBoxCoder
 from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder

 __all__ = [
-    'build_bbox_coder', 'DeltaXYZWLHRBBoxCoder', 'PartialBinBasedBBoxCoder'
+    'build_bbox_coder', 'DeltaXYZWLHRBBoxCoder', 'PartialBinBasedBBoxCoder',
+    'AnchorFreeBBoxCoder'
 ]
import numpy as np
import torch
from mmdet.core.bbox.builder import BBOX_CODERS
from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder
@BBOX_CODERS.register_module()
class AnchorFreeBBoxCoder(PartialBinBasedBBoxCoder):
"""Anchor free bbox coder for 3D boxes.
Args:
num_dir_bins (int): Number of bins to encode direction angle.
with_rot (bool): Whether the bbox is with rotation.
"""
def __init__(self, num_dir_bins, with_rot=True):
super(AnchorFreeBBoxCoder, self).__init__(
num_dir_bins, 0, [], with_rot=with_rot)
self.num_dir_bins = num_dir_bins
self.with_rot = with_rot
def encode(self, gt_bboxes_3d, gt_labels_3d):
"""Encode ground truth to prediction targets.
Args:
gt_bboxes_3d (BaseInstance3DBoxes): Ground truth bboxes \
with shape (n, 7).
gt_labels_3d (torch.Tensor): Ground truth classes.
Returns:
tuple: Targets of center, size and direction.
"""
# generate center target
center_target = gt_bboxes_3d.gravity_center
# generate bbox size target
size_res_target = gt_bboxes_3d.dims / 2
# generate dir target
box_num = gt_labels_3d.shape[0]
if self.with_rot:
(dir_class_target,
dir_res_target) = self.angle2class(gt_bboxes_3d.yaw)
dir_res_target /= (2 * np.pi / self.num_dir_bins)
else:
dir_class_target = gt_labels_3d.new_zeros(box_num)
dir_res_target = gt_bboxes_3d.tensor.new_zeros(box_num)
return (center_target, size_res_target, dir_class_target,
dir_res_target)
def decode(self, bbox_out):
"""Decode predicted parts to bbox3d.
Args:
bbox_out (dict): Predictions from model, should contain keys below.
- center: predicted bottom center of bboxes.
- dir_class: predicted bbox direction class.
- dir_res: predicted bbox direction residual.
- size: predicted bbox size.
Returns:
torch.Tensor: Decoded bbox3d with shape (batch, n, 7).
"""
center = bbox_out['center']
batch_size, num_proposal = center.shape[:2]
# decode heading angle
if self.with_rot:
dir_class = torch.argmax(bbox_out['dir_class'], -1)
dir_res = torch.gather(bbox_out['dir_res'], 2,
dir_class.unsqueeze(-1))
dir_res.squeeze_(2)
dir_angle = self.class2angle(dir_class, dir_res).reshape(
batch_size, num_proposal, 1)
else:
dir_angle = center.new_zeros(batch_size, num_proposal, 1)
# decode bbox size
bbox_size = torch.clamp(bbox_out['size'] * 2, min=0.1)
bbox3d = torch.cat([center, bbox_size, dir_angle], dim=-1)
return bbox3d
def split_pred(self, cls_preds, reg_preds, base_xyz):
"""Split predicted features to specific parts.
Args:
cls_preds (torch.Tensor): Class predicted features to split.
reg_preds (torch.Tensor): Regression predicted features to split.
base_xyz (torch.Tensor): Coordinates of points.
Returns:
dict[str, torch.Tensor]: Split results.
"""
results = {}
results['obj_scores'] = cls_preds
start, end = 0, 0
reg_preds_trans = reg_preds.transpose(2, 1)
# decode center
end += 3
# (batch_size, num_proposal, 3)
results['center_offset'] = reg_preds_trans[..., start:end]
results['center'] = base_xyz.detach() + reg_preds_trans[..., start:end]
start = end
        # decode size
end += 3
# (batch_size, num_proposal, 3)
results['size'] = reg_preds_trans[..., start:end]
start = end
# decode direction
end += self.num_dir_bins
results['dir_class'] = reg_preds_trans[..., start:end]
start = end
end += self.num_dir_bins
dir_res_norm = reg_preds_trans[..., start:end]
start = end
results['dir_res_norm'] = dir_res_norm
results['dir_res'] = dir_res_norm * (2 * np.pi / self.num_dir_bins)
return results
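
The encode() above stores each direction as a bin index plus a residual normalized by the bin width. A self-contained re-derivation of that binning (the real angle2class lives in PartialBinBasedBBoxCoder; the helper below is an illustrative standalone version, not the library function):

import numpy as np
import torch

num_dir_bins = 12
angle_per_class = 2 * np.pi / num_dir_bins

def angle2class(angle):
    # Bin i is centered at i * angle_per_class; return (bin index, residual).
    angle = angle % (2 * np.pi)
    shifted = (angle + angle_per_class / 2) % (2 * np.pi)
    angle_cls = (shifted / angle_per_class).floor().long()
    angle_res = shifted - (angle_cls.float() * angle_per_class +
                           angle_per_class / 2)
    return angle_cls, angle_res

yaw = torch.tensor([-1.5501])  # first box of the unit test at the bottom
angle_cls, angle_res = angle2class(yaw)
print(angle_cls.item())                      # 9
print((angle_res / angle_per_class).item())  # ~0.039, the normalized residual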
@@ -98,11 +98,12 @@ class PartialBinBasedBBoxCoder(BaseBBoxCoder):
         bbox3d = torch.cat([center, bbox_size, dir_angle], dim=-1)
         return bbox3d

-    def split_pred(self, preds, base_xyz):
+    def split_pred(self, cls_preds, reg_preds, base_xyz):
         """Split predicted features to specific parts.

         Args:
-            preds (torch.Tensor): Predicted features to split.
+            cls_preds (torch.Tensor): Class predicted features to split.
+            reg_preds (torch.Tensor): Regression predicted features to split.
             base_xyz (torch.Tensor): Coordinates of points.

         Returns:
@@ -110,26 +111,24 @@ class PartialBinBasedBBoxCoder(BaseBBoxCoder):
         """
         results = {}
         start, end = 0, 0
-        preds_trans = preds.transpose(2, 1)
-
-        # decode objectness score
-        end += 2
-        results['obj_scores'] = preds_trans[..., start:end].contiguous()
-        start = end
+        cls_preds_trans = cls_preds.transpose(2, 1)
+        reg_preds_trans = reg_preds.transpose(2, 1)

         # decode center
         end += 3
         # (batch_size, num_proposal, 3)
-        results['center'] = base_xyz + preds_trans[..., start:end].contiguous()
+        results['center'] = base_xyz + \
+            reg_preds_trans[..., start:end].contiguous()
         start = end

         # decode direction
         end += self.num_dir_bins
-        results['dir_class'] = preds_trans[..., start:end].contiguous()
+        results['dir_class'] = reg_preds_trans[..., start:end].contiguous()
         start = end

         end += self.num_dir_bins
-        dir_res_norm = preds_trans[..., start:end].contiguous()
+        dir_res_norm = reg_preds_trans[..., start:end].contiguous()
         start = end

         results['dir_res_norm'] = dir_res_norm
@@ -137,23 +136,29 @@ class PartialBinBasedBBoxCoder(BaseBBoxCoder):

         # decode size
         end += self.num_sizes
-        results['size_class'] = preds_trans[..., start:end].contiguous()
+        results['size_class'] = reg_preds_trans[..., start:end].contiguous()
         start = end

         end += self.num_sizes * 3
-        size_res_norm = preds_trans[..., start:end]
-        batch_size, num_proposal = preds_trans.shape[:2]
+        size_res_norm = reg_preds_trans[..., start:end]
+        batch_size, num_proposal = reg_preds_trans.shape[:2]
         size_res_norm = size_res_norm.view(
             [batch_size, num_proposal, self.num_sizes, 3])
         start = end

         results['size_res_norm'] = size_res_norm.contiguous()
-        mean_sizes = preds.new_tensor(self.mean_sizes)
+        mean_sizes = reg_preds.new_tensor(self.mean_sizes)
         results['size_res'] = (
             size_res_norm * mean_sizes.unsqueeze(0).unsqueeze(0))

+        # decode objectness score
+        start = 0
+        end = 2
+        results['obj_scores'] = cls_preds_trans[..., start:end].contiguous()
+        start = end
+
         # decode semantic score
-        results['sem_scores'] = preds_trans[..., start:].contiguous()
+        results['sem_scores'] = cls_preds_trans[..., start:].contiguous()

         return results
...
@@ -252,6 +252,21 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):

         return box_idxs_of_pts.squeeze(0)

+    def enlarged_box(self, extra_width):
+        """Enlarge the length, width and height of the boxes.
+
+        Args:
+            extra_width (float | torch.Tensor): Extra width to enlarge the box.
+
+        Returns:
+            :obj:`DepthInstance3DBoxes`: Enlarged boxes.
+        """
+        enlarged_boxes = self.tensor.clone()
+        enlarged_boxes[:, 3:6] += extra_width * 2
+        # bottom center z minus extra_width
+        enlarged_boxes[:, 2] -= extra_width
+        return self.new_box(enlarged_boxes)
+
     def get_surface_line_center(self):
         """Compute surface and line center of bounding boxes.
...
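
A usage sketch for the new method, assuming only that DepthInstance3DBoxes is importable as shown in the tests at the end of this diff:

import torch
from mmdet3d.core.bbox import DepthInstance3DBoxes

# One unit cube whose bottom center sits at the origin.
boxes = DepthInstance3DBoxes(torch.tensor([[0., 0., 0., 1., 1., 1., 0.]]))
bigger = boxes.enlarged_box(0.5)
# Dims grow by 2 * extra_width and the bottom center drops by extra_width,
# so the box is enlarged symmetrically along z as well.
print(bigger.tensor)  # [[0.0, 0.0, -0.5, 2.0, 2.0, 2.0, 0.0]]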
@@ -475,6 +475,7 @@ def eval_class(gt_annos,
     if num_examples < num_parts:
         num_parts = num_examples
     split_parts = get_split_parts(num_examples, num_parts)
+
     rets = calculate_iou_partly(dt_annos, gt_annos, metric, num_parts)
     overlaps, parted_overlaps, total_dt_num, total_gt_num = rets
     N_SAMPLE_PTS = 41
...
@@ -24,6 +24,7 @@ class PointNet2SAMSG(BasePointNet):
         fps_mods (tuple[int]): Mod of FPS for each SA module.
         fps_sample_range_lists (tuple[tuple[int]]): The number of sampling
             points which each SA module samples.
+        dilated_group (tuple[bool]): Whether to use dilated ball query for
+            each SA module.
         out_indices (Sequence[int]): Output from which stages.
         norm_cfg (dict): Config of normalization layer.
         sa_cfg (dict): Config of set abstraction module, which may contain
@@ -47,13 +48,14 @@ class PointNet2SAMSG(BasePointNet):
                  aggregation_channels=(64, 128, 256),
                  fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
                  fps_sample_range_lists=((-1), (-1), (512, -1)),
+                 dilated_group=(True, True, True),
                  out_indices=(2, ),
                  norm_cfg=dict(type='BN2d'),
                  sa_cfg=dict(
                      type='PointSAModuleMSG',
                      pool_mod='max',
                      use_xyz=True,
-                     normalize_xyz=True)):
+                     normalize_xyz=False)):
         super().__init__()
         self.num_sa = len(sa_channels)
         self.out_indices = out_indices
@@ -94,6 +96,7 @@ class PointNet2SAMSG(BasePointNet):
                     mlp_channels=cur_sa_mlps,
                     fps_mod=cur_fps_mod,
                     fps_sample_range_list=cur_fps_sample_range_list,
+                    dilated_group=dilated_group[sa_index],
                     norm_cfg=norm_cfg,
                     cfg=sa_cfg,
                     bias=True))
@@ -137,6 +140,7 @@ class PointNet2SAMSG(BasePointNet):
         out_sa_xyz = []
         out_sa_features = []
         out_sa_indices = []
+
         for i in range(self.num_sa):
             cur_xyz, cur_features, cur_indices = self.SA_modules[i](
                 sa_xyz[i], sa_features[i])
...
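
On the sampling vocabulary above: 'D-FPS' is farthest point sampling in coordinate space, 'F-FPS' samples by feature distance, and 'FS' (fusion sampling) roughly combines the two; a fps_sample_range_list entry of 512 restricts one sampler to the first 512 points while -1 covers the remainder. A naive, self-contained reference of distance-based FPS (the CUDA op in mmdet3d.ops implements the same selection, just fast):

import torch

def farthest_point_sample_naive(xyz, npoint):
    # xyz: (B, N, 3). Greedily pick the point farthest from those chosen so far.
    B, N, _ = xyz.shape
    idx = torch.zeros(B, npoint, dtype=torch.long)
    min_dist = torch.full((B, N), float('inf'))
    farthest = torch.zeros(B, dtype=torch.long)
    batch = torch.arange(B)
    for i in range(npoint):
        idx[:, i] = farthest
        centroid = xyz[batch, farthest].unsqueeze(1)  # (B, 1, 3)
        min_dist = torch.minimum(min_dist, ((xyz - centroid)**2).sum(-1))
        farthest = min_dist.argmax(-1)
    return idx

points = torch.rand(2, 1024, 3)
print(farthest_point_sample_naive(points, 16).shape)  # torch.Size([2, 16])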
 from .anchor3d_head import Anchor3DHead
+from .base_conv_bbox_head import BaseConvBboxHead
 from .free_anchor3d_head import FreeAnchor3DHead
 from .parta2_rpn_head import PartA2RPNHead
+from .ssd_3d_head import SSD3DHead
 from .vote_head import VoteHead

-__all__ = ['Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead']
+__all__ = [
+    'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead',
+    'SSD3DHead', 'BaseConvBboxHead'
+]
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_conv_layer
from torch import nn as nn
from mmdet.models.builder import HEADS
@HEADS.register_module()
class BaseConvBboxHead(nn.Module):
r"""More general bbox head, with shared conv layers and two optional
separated branches.
.. code-block:: none
/-> cls convs -> cls_score
shared convs
\-> reg convs -> bbox_pred
"""
def __init__(self,
in_channels=0,
shared_conv_channels=(),
cls_conv_channels=(),
num_cls_out_channels=0,
reg_conv_channels=(),
num_reg_out_channels=0,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU'),
bias='auto',
*args,
**kwargs):
super(BaseConvBboxHead, self).__init__(*args, **kwargs)
assert in_channels > 0
assert num_cls_out_channels > 0
assert num_reg_out_channels > 0
self.in_channels = in_channels
self.shared_conv_channels = shared_conv_channels
self.cls_conv_channels = cls_conv_channels
self.num_cls_out_channels = num_cls_out_channels
self.reg_conv_channels = reg_conv_channels
self.num_reg_out_channels = num_reg_out_channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.bias = bias
# add shared convs
if len(self.shared_conv_channels) > 0:
self.shared_convs = self._add_conv_branch(
self.in_channels, self.shared_conv_channels)
out_channels = self.shared_conv_channels[-1]
else:
out_channels = self.in_channels
# add cls specific branch
prev_channel = out_channels
if len(self.cls_conv_channels) > 0:
self.cls_convs = self._add_conv_branch(prev_channel,
self.cls_conv_channels)
prev_channel = self.cls_conv_channels[-1]
self.conv_cls = build_conv_layer(
conv_cfg,
in_channels=prev_channel,
out_channels=num_cls_out_channels,
kernel_size=1)
# add reg specific branch
prev_channel = out_channels
if len(self.reg_conv_channels) > 0:
self.reg_convs = self._add_conv_branch(prev_channel,
self.reg_conv_channels)
prev_channel = self.reg_conv_channels[-1]
self.conv_reg = build_conv_layer(
conv_cfg,
in_channels=prev_channel,
out_channels=num_reg_out_channels,
kernel_size=1)
def _add_conv_branch(self, in_channels, conv_channels):
"""Add shared or separable branch."""
conv_spec = [in_channels] + list(conv_channels)
# add branch specific conv layers
conv_layers = nn.Sequential()
for i in range(len(conv_spec) - 1):
conv_layers.add_module(
f'layer{i}',
ConvModule(
conv_spec[i],
conv_spec[i + 1],
kernel_size=1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
bias=self.bias,
inplace=True))
return conv_layers
def init_weights(self):
# conv layers are already initialized by ConvModule
pass
def forward(self, feats):
"""Forward.
Args:
feats (Tensor): Input features
Returns:
Tensor: Class scores predictions
Tensor: Regression predictions
"""
# shared part
if len(self.shared_conv_channels) > 0:
x = self.shared_convs(feats)
# separate branches
x_cls = x
x_reg = x
if len(self.cls_conv_channels) > 0:
x_cls = self.cls_convs(x_cls)
cls_score = self.conv_cls(x_cls)
if len(self.reg_conv_channels) > 0:
x_reg = self.reg_convs(x_reg)
bbox_pred = self.conv_reg(x_reg)
return cls_score, bbox_pred
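
A usage sketch for the new head, with the channel numbers VoteHead derives below for the SUN RGB-D setting (10 classes, 12 direction bins, 10 size templates, matching the updated unit test):

import torch

from mmdet3d.models.dense_heads import BaseConvBboxHead

head = BaseConvBboxHead(
    in_channels=128,
    shared_conv_channels=(128, 128),
    num_cls_out_channels=12,  # num_classes (10) + objectness (2)
    num_reg_out_channels=67)  # 3 + 12 * 2 + 10 * 4
feats = torch.rand(2, 128, 256)  # (batch, channels, num_proposal)
cls_score, bbox_pred = head(feats)
assert cls_score.shape == (2, 12, 256)
assert bbox_pred.shape == (2, 67, 256)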
This diff is collapsed.
 import numpy as np
 import torch
-from mmcv.cnn import ConvModule
 from torch import nn as nn
 from torch.nn import functional as F

@@ -11,6 +10,7 @@ from mmdet3d.models.model_utils import VoteModule
 from mmdet3d.ops import build_sa_module, furthest_point_sample
 from mmdet.core import build_bbox_coder, multi_apply
 from mmdet.models import HEADS
+from .base_conv_bbox_head import BaseConvBboxHead


 @HEADS.register_module()
@@ -23,10 +23,10 @@ class VoteHead(nn.Module):
             decoding boxes.
         train_cfg (dict): Config for training.
         test_cfg (dict): Config for testing.
-        vote_moudule_cfg (dict): Config of VoteModule for point-wise votes.
+        vote_module_cfg (dict): Config of VoteModule for point-wise votes.
         vote_aggregation_cfg (dict): Config of vote aggregation layer.
-        feat_channels (tuple[int]): Convolution channels of
-            prediction layer.
+        pred_layer_cfg (dict): Config of classification and regression
+            prediction layers.
         conv_cfg (dict): Config of convolution in prediction layer.
         norm_cfg (dict): Config of BN in prediction layer.
         objectness_loss (dict): Config of objectness loss.
@@ -43,9 +43,9 @@ class VoteHead(nn.Module):
                  bbox_coder,
                  train_cfg=None,
                  test_cfg=None,
-                 vote_moudule_cfg=None,
+                 vote_module_cfg=None,
                  vote_aggregation_cfg=None,
-                 feat_channels=(128, 128),
+                 pred_layer_cfg=None,
                  conv_cfg=dict(type='Conv1d'),
                  norm_cfg=dict(type='BN1d'),
                  objectness_loss=None,
@@ -59,54 +59,64 @@ class VoteHead(nn.Module):
         self.num_classes = num_classes
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
-        self.gt_per_seed = vote_moudule_cfg['gt_per_seed']
+        self.gt_per_seed = vote_module_cfg['gt_per_seed']
         self.num_proposal = vote_aggregation_cfg['num_point']

         self.objectness_loss = build_loss(objectness_loss)
         self.center_loss = build_loss(center_loss)
-        self.dir_class_loss = build_loss(dir_class_loss)
         self.dir_res_loss = build_loss(dir_res_loss)
-        self.size_class_loss = build_loss(size_class_loss)
+        self.dir_class_loss = build_loss(dir_class_loss)
         self.size_res_loss = build_loss(size_res_loss)
-        self.semantic_loss = build_loss(semantic_loss)
+        if size_class_loss is not None:
+            self.size_class_loss = build_loss(size_class_loss)
+        if semantic_loss is not None:
+            self.semantic_loss = build_loss(semantic_loss)

-        assert vote_aggregation_cfg['mlp_channels'][0] == vote_moudule_cfg[
-            'in_channels']
-
         self.bbox_coder = build_bbox_coder(bbox_coder)
         self.num_sizes = self.bbox_coder.num_sizes
         self.num_dir_bins = self.bbox_coder.num_dir_bins

-        self.vote_module = VoteModule(**vote_moudule_cfg)
+        self.vote_module = VoteModule(**vote_module_cfg)
         self.vote_aggregation = build_sa_module(vote_aggregation_cfg)

-        prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
-        conv_pred_list = list()
-        for k in range(len(feat_channels)):
-            conv_pred_list.append(
-                ConvModule(
-                    prev_channel,
-                    feat_channels[k],
-                    1,
-                    padding=0,
-                    conv_cfg=conv_cfg,
-                    norm_cfg=norm_cfg,
-                    bias=True,
-                    inplace=True))
-            prev_channel = feat_channels[k]
-        self.conv_pred = nn.Sequential(*conv_pred_list)
-
-        # Objectness scores (2), center residual (3),
-        # heading class+residual (num_dir_bins*2),
-        # size class+residual(num_sizes*4)
-        conv_out_channel = (2 + 3 + self.num_dir_bins * 2 +
-                            self.num_sizes * 4 + num_classes)
-        self.conv_pred.add_module('conv_out',
-                                  nn.Conv1d(prev_channel, conv_out_channel, 1))
+        # Bbox classification and regression
+        self.conv_pred = BaseConvBboxHead(
+            **pred_layer_cfg,
+            num_cls_out_channels=self._get_cls_out_channels(),
+            num_reg_out_channels=self._get_reg_out_channels())

     def init_weights(self):
         """Initialize weights of VoteHead."""
         pass

+    def _get_cls_out_channels(self):
+        """Return the channel number of classification outputs."""
+        # Class numbers (k) + objectness (2)
+        return self.num_classes + 2
+
+    def _get_reg_out_channels(self):
+        """Return the channel number of regression outputs."""
+        # Center residual (3), heading class+residual (num_dir_bins*2),
+        # size class+residual (num_sizes*4)
+        return 3 + self.num_dir_bins * 2 + self.num_sizes * 4
+
+    def _extract_input(self, feat_dict):
+        """Extract inputs from features dictionary.
+
+        Args:
+            feat_dict (dict): Feature dict from backbone.
+
+        Returns:
+            torch.Tensor: Coordinates of input points.
+            torch.Tensor: Features of input points.
+            torch.Tensor: Indices of input points.
+        """
+        seed_points = feat_dict['fp_xyz'][-1]
+        seed_features = feat_dict['fp_features'][-1]
+        seed_indices = feat_dict['fp_indices'][-1]
+
+        return seed_points, seed_features, seed_indices
+
     def forward(self, feat_dict, sample_mod):
         """Forward pass.

@@ -122,57 +132,74 @@ class VoteHead(nn.Module):
         Args:
             feat_dict (dict): Feature dict from backbone.
             sample_mod (str): Sample mode for vote aggregation layer.
-                valid modes are "vote", "seed" and "random".
+                Valid modes are "vote", "seed", "random" and "spec".

         Returns:
             dict: Predictions of vote head.
         """
-        assert sample_mod in ['vote', 'seed', 'random']
+        assert sample_mod in ['vote', 'seed', 'random', 'spec']

-        seed_points = feat_dict['fp_xyz'][-1]
-        seed_features = feat_dict['fp_features'][-1]
-        seed_indices = feat_dict['fp_indices'][-1]
+        seed_points, seed_features, seed_indices = self._extract_input(
+            feat_dict)

         # 1. generate vote_points from seed_points
-        vote_points, vote_features = self.vote_module(seed_points,
-                                                      seed_features)
+        vote_points, vote_features, vote_offset = self.vote_module(
+            seed_points, seed_features)
         results = dict(
             seed_points=seed_points,
             seed_indices=seed_indices,
             vote_points=vote_points,
-            vote_features=vote_features)
+            vote_features=vote_features,
+            vote_offset=vote_offset)

         # 2. aggregate vote_points
         if sample_mod == 'vote':
             # use fps in vote_aggregation
-            sample_indices = None
+            aggregation_inputs = dict(
+                points_xyz=vote_points, features=vote_features)
         elif sample_mod == 'seed':
             # FPS on seed and choose the votes corresponding to the seeds
             sample_indices = furthest_point_sample(seed_points,
                                                    self.num_proposal)
+            aggregation_inputs = dict(
+                points_xyz=vote_points,
+                features=vote_features,
+                indices=sample_indices)
         elif sample_mod == 'random':
             # Random sampling from the votes
             batch_size, num_seed = seed_points.shape[:2]
             sample_indices = seed_points.new_tensor(
                 torch.randint(0, num_seed, (batch_size, self.num_proposal)),
                 dtype=torch.int32)
+            aggregation_inputs = dict(
+                points_xyz=vote_points,
+                features=vote_features,
+                indices=sample_indices)
+        elif sample_mod == 'spec':
+            # Specify the new center in vote_aggregation
+            aggregation_inputs = dict(
+                points_xyz=seed_points,
+                features=seed_features,
+                target_xyz=vote_points)
         else:
             raise NotImplementedError(
                 f'Sample mode {sample_mod} is not supported!')

-        vote_aggregation_ret = self.vote_aggregation(vote_points,
-                                                     vote_features,
-                                                     sample_indices)
+        vote_aggregation_ret = self.vote_aggregation(**aggregation_inputs)
         aggregated_points, features, aggregated_indices = vote_aggregation_ret
         results['aggregated_points'] = aggregated_points
         results['aggregated_features'] = features
         results['aggregated_indices'] = aggregated_indices

         # 3. predict bbox and score
-        predictions = self.conv_pred(features)
+        cls_predictions, reg_predictions = self.conv_pred(features)

         # 4. decode predictions
-        decode_res = self.bbox_coder.split_pred(predictions, aggregated_points)
+        decode_res = self.bbox_coder.split_pred(cls_predictions,
+                                                reg_predictions,
+                                                aggregated_points)
         results.update(decode_res)

         return results
...
@@ -4,10 +4,12 @@ from .h3dnet import H3DNet
 from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN
 from .mvx_two_stage import MVXTwoStageDetector
 from .parta2 import PartA2
+from .ssd3dnet import SSD3DNet
 from .votenet import VoteNet
 from .voxelnet import VoxelNet

 __all__ = [
     'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector',
-    'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet'
+    'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet',
+    'SSD3DNet'
 ]
from mmdet.models import DETECTORS
from .votenet import VoteNet
@DETECTORS.register_module()
class SSD3DNet(VoteNet):
"""3DSSDNet model.
https://arxiv.org/abs/2002.10187.pdf
"""
def __init__(self,
backbone,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(SSD3DNet, self).__init__(
backbone=backbone,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
 import torch
+from mmcv import is_tuple_of
 from mmcv.cnn import ConvModule
 from torch import nn as nn

@@ -15,6 +16,7 @@ class VoteModule(nn.Module):
         vote_per_seed (int): Number of votes generated from each seed point.
         gt_per_seed (int): Number of ground truth votes generated
             from each seed point.
+        num_points (int): Number of points to be used for voting.
         conv_channels (tuple[int]): Out channels of vote
             generating convolution.
         conv_cfg (dict): Config of convolution.
@@ -23,6 +25,9 @@ class VoteModule(nn.Module):
             Default: dict(type='BN1d').
         norm_feats (bool): Whether to normalize features.
             Default: True.
+        with_res_feat (bool): Whether to predict residual features.
+            Default: True.
+        vote_xyz_range (list[float], None): The range of points translation.
         vote_loss (dict): Config of vote loss.
     """

@@ -30,16 +35,27 @@ class VoteModule(nn.Module):
                  in_channels,
                  vote_per_seed=1,
                  gt_per_seed=3,
+                 num_points=-1,
                  conv_channels=(16, 16),
                  conv_cfg=dict(type='Conv1d'),
                  norm_cfg=dict(type='BN1d'),
+                 act_cfg=dict(type='ReLU'),
                  norm_feats=True,
+                 with_res_feat=True,
+                 vote_xyz_range=None,
                  vote_loss=None):
         super().__init__()
         self.in_channels = in_channels
         self.vote_per_seed = vote_per_seed
         self.gt_per_seed = gt_per_seed
+        self.num_points = num_points
         self.norm_feats = norm_feats
+        self.with_res_feat = with_res_feat
+
+        assert vote_xyz_range is None or is_tuple_of(vote_xyz_range, float)
+        self.vote_xyz_range = vote_xyz_range

-        self.vote_loss = build_loss(vote_loss)
+        if vote_loss is not None:
+            self.vote_loss = build_loss(vote_loss)

         prev_channels = in_channels
@@ -53,13 +69,17 @@ class VoteModule(nn.Module):
                     padding=0,
                     conv_cfg=conv_cfg,
                     norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
                     bias=True,
                     inplace=True))
             prev_channels = conv_channels[k]
         self.vote_conv = nn.Sequential(*vote_conv_list)

         # conv_out predicts coordinate and residual features
-        out_channel = (3 + in_channels) * self.vote_per_seed
+        if with_res_feat:
+            out_channel = (3 + in_channels) * self.vote_per_seed
+        else:
+            out_channel = 3 * self.vote_per_seed
         self.conv_out = nn.Conv1d(prev_channels, out_channel, 1)

     def forward(self, seed_points, seed_feats):
@@ -80,6 +100,13 @@ class VoteModule(nn.Module):
                 shape (B, C, M) where ``M=num_seed*vote_per_seed``, \
                 ``C=vote_feature_dim``.
         """
+        if self.num_points != -1:
+            assert self.num_points < seed_points.shape[1], \
+                f'Number of vote points ({self.num_points}) should be '\
+                f'smaller than seed points size ({seed_points.shape[1]})'
+            seed_points = seed_points[:, :self.num_points]
+            seed_feats = seed_feats[..., :self.num_points]
+
         batch_size, feat_channels, num_seed = seed_feats.shape
         num_vote = num_seed * self.vote_per_seed
         x = self.vote_conv(seed_feats)
@@ -88,21 +115,36 @@ class VoteModule(nn.Module):
         votes = votes.transpose(2, 1).view(batch_size, num_seed,
                                            self.vote_per_seed, -1)

-        offset = votes[:, :, :, 0:3].contiguous()
-        res_feats = votes[:, :, :, 3:].contiguous()
-
-        vote_points = seed_points.unsqueeze(2) + offset
+        offset = votes[:, :, :, 0:3]
+        if self.vote_xyz_range is not None:
+            limited_offset_list = []
+            for axis in range(len(self.vote_xyz_range)):
+                limited_offset_list.append(offset[..., axis].clamp(
+                    min=-self.vote_xyz_range[axis],
+                    max=self.vote_xyz_range[axis]))
+            limited_offset = torch.stack(limited_offset_list, -1)
+            vote_points = (seed_points.unsqueeze(2) +
+                           limited_offset).contiguous()
+        else:
+            vote_points = (seed_points.unsqueeze(2) + offset).contiguous()
         vote_points = vote_points.view(batch_size, num_vote, 3)
-        vote_feats = seed_feats.permute(
-            0, 2, 1).unsqueeze(2).contiguous() + res_feats
-        vote_feats = vote_feats.view(batch_size, num_vote,
-                                     feat_channels).transpose(2,
-                                                              1).contiguous()
+        offset = offset.reshape(batch_size, num_vote, 3).transpose(2, 1)

-        if self.norm_feats:
-            features_norm = torch.norm(vote_feats, p=2, dim=1)
-            vote_feats = vote_feats.div(features_norm.unsqueeze(1))
-        return vote_points, vote_feats
+        if self.with_res_feat:
+            res_feats = votes[:, :, :, 3:]
+            vote_feats = (seed_feats.transpose(2, 1).unsqueeze(2) +
+                          res_feats).contiguous()
+            vote_feats = vote_feats.view(batch_size,
+                                         num_vote, feat_channels).transpose(
+                                             2, 1).contiguous()
+
+            if self.norm_feats:
+                features_norm = torch.norm(vote_feats, p=2, dim=1)
+                vote_feats = vote_feats.div(features_norm.unsqueeze(1))
+        else:
+            vote_feats = seed_feats
+        return vote_points, vote_feats, offset

     def get_loss(self, seed_points, vote_points, seed_indices,
                  vote_targets_mask, vote_targets):
...
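
A standalone sketch of the new per-axis clamping (vote_xyz_range): each offset coordinate is limited independently before being added to the seed coordinates:

import torch

vote_xyz_range = (3.0, 3.0, 2.0)
offset = torch.tensor([[[[4.0, -5.0, 1.0]]]])  # (B, num_seed, vote_per_seed, 3)
limited_offset = torch.stack(
    [offset[..., axis].clamp(min=-r, max=r)
     for axis, r in enumerate(vote_xyz_range)],
    dim=-1)
print(limited_offset)  # tensor([[[[ 3., -3.,  1.]]]])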
@@ -308,8 +308,9 @@ class H3DBboxHead(nn.Module):
         for conv_module in self.bbox_pred[1:]:
             bbox_predictions = conv_module(bbox_predictions)

-        refine_decode_res = self.bbox_coder.split_pred(bbox_predictions,
-                                                       aggregated_points)
+        refine_decode_res = self.bbox_coder.split_pred(
+            bbox_predictions[:, :self.num_classes + 2],
+            bbox_predictions[:, self.num_classes + 2:], aggregated_points)
         for key in refine_decode_res.keys():
             ret_dict[key + '_optimized'] = refine_decode_res[key]
         return ret_dict
...
@@ -23,7 +23,7 @@ class PrimitiveHead(nn.Module):
             decoding boxes.
         train_cfg (dict): Config for training.
         test_cfg (dict): Config for testing.
-        vote_moudule_cfg (dict): Config of VoteModule for point-wise votes.
+        vote_module_cfg (dict): Config of VoteModule for point-wise votes.
         vote_aggregation_cfg (dict): Config of vote aggregation layer.
         feat_channels (tuple[int]): Convolution channels of
             prediction layer.
@@ -42,7 +42,7 @@ class PrimitiveHead(nn.Module):
                  primitive_mode,
                  train_cfg=None,
                  test_cfg=None,
-                 vote_moudule_cfg=None,
+                 vote_module_cfg=None,
                  vote_aggregation_cfg=None,
                  feat_channels=(128, 128),
                  upper_thresh=100.0,
@@ -61,7 +61,7 @@ class PrimitiveHead(nn.Module):
         self.primitive_mode = primitive_mode
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
-        self.gt_per_seed = vote_moudule_cfg['gt_per_seed']
+        self.gt_per_seed = vote_module_cfg['gt_per_seed']
         self.num_proposal = vote_aggregation_cfg['num_point']
         self.upper_thresh = upper_thresh
         self.surface_thresh = surface_thresh
@@ -71,13 +71,13 @@ class PrimitiveHead(nn.Module):
         self.semantic_reg_loss = build_loss(semantic_reg_loss)
         self.semantic_cls_loss = build_loss(semantic_cls_loss)

-        assert vote_aggregation_cfg['mlp_channels'][0] == vote_moudule_cfg[
+        assert vote_aggregation_cfg['mlp_channels'][0] == vote_module_cfg[
             'in_channels']

         # Primitive existence flag prediction
         self.flag_conv = ConvModule(
-            vote_moudule_cfg['conv_channels'][-1],
-            vote_moudule_cfg['conv_channels'][-1] // 2,
+            vote_module_cfg['conv_channels'][-1],
+            vote_module_cfg['conv_channels'][-1] // 2,
             1,
             padding=0,
             conv_cfg=conv_cfg,
@@ -85,9 +85,9 @@ class PrimitiveHead(nn.Module):
             bias=True,
             inplace=True)
         self.flag_pred = torch.nn.Conv1d(
-            vote_moudule_cfg['conv_channels'][-1] // 2, 2, 1)
+            vote_module_cfg['conv_channels'][-1] // 2, 2, 1)

-        self.vote_module = VoteModule(**vote_moudule_cfg)
+        self.vote_module = VoteModule(**vote_module_cfg)
         self.vote_aggregation = build_sa_module(vote_aggregation_cfg)

         prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
@@ -137,8 +137,8 @@ class PrimitiveHead(nn.Module):
         results['pred_flag_' + self.primitive_mode] = primitive_flag

         # 1. generate vote_points from seed_points
-        vote_points, vote_features = self.vote_module(seed_points,
-                                                      seed_features)
+        vote_points, vote_features, _ = self.vote_module(
+            seed_points, seed_features)
         results['vote_' + self.primitive_mode] = vote_points
         results['vote_features_' + self.primitive_mode] = vote_features
...
 import torch

-from mmdet3d.core.bbox import DepthInstance3DBoxes
+from mmdet3d.core.bbox import DepthInstance3DBoxes, LiDARInstance3DBoxes
 from mmdet.core import build_bbox_coder

@@ -194,9 +194,10 @@ def test_partial_bin_based_box_coder():
     assert torch.allclose(bbox3d, expected_bbox3d, atol=1e-4)

     # test split_pred
-    box_preds = torch.rand(2, 79, 256)
+    cls_preds = torch.rand(2, 12, 256)
+    reg_preds = torch.rand(2, 67, 256)
     base_xyz = torch.rand(2, 256, 3)
-    results = box_coder.split_pred(box_preds, base_xyz)
+    results = box_coder.split_pred(cls_preds, reg_preds, base_xyz)
     obj_scores = results['obj_scores']
     center = results['center']
     dir_class = results['dir_class']
@@ -215,3 +216,110 @@ def test_partial_bin_based_box_coder():
     assert size_res_norm.shape == torch.Size([2, 256, 10, 3])
     assert size_res.shape == torch.Size([2, 256, 10, 3])
     assert sem_scores.shape == torch.Size([2, 256, 10])
def test_anchor_free_box_coder():
box_coder_cfg = dict(
type='AnchorFreeBBoxCoder', num_dir_bins=12, with_rot=True)
box_coder = build_bbox_coder(box_coder_cfg)
# test encode
gt_bboxes = LiDARInstance3DBoxes([[
2.1227e+00, 5.7951e+00, -9.9900e-01, 1.6736e+00, 4.2419e+00,
1.5473e+00, -1.5501e+00
],
[
1.1791e+01, 9.0276e+00, -8.5772e-01,
1.6210e+00, 3.5367e+00, 1.4841e+00,
-1.7369e+00
],
[
2.3638e+01, 9.6997e+00, -5.6713e-01,
1.7578e+00, 4.6103e+00, 1.5999e+00,
-1.4556e+00
]])
gt_labels = torch.tensor([0, 0, 0])
(center_targets, size_targets, dir_class_targets,
dir_res_targets) = box_coder.encode(gt_bboxes, gt_labels)
expected_center_target = torch.tensor([[2.1227, 5.7951, -0.2253],
[11.7908, 9.0276, -0.1156],
[23.6380, 9.6997, 0.2328]])
expected_size_targets = torch.tensor([[0.8368, 2.1210, 0.7736],
[0.8105, 1.7683, 0.7421],
[0.8789, 2.3052, 0.8000]])
expected_dir_class_target = torch.tensor([9, 9, 9])
expected_dir_res_target = torch.tensor([0.0394, -0.3172, 0.2199])
assert torch.allclose(center_targets, expected_center_target, atol=1e-4)
assert torch.allclose(size_targets, expected_size_targets, atol=1e-4)
assert torch.all(dir_class_targets == expected_dir_class_target)
assert torch.allclose(dir_res_targets, expected_dir_res_target, atol=1e-3)
# test decode
center = torch.tensor([[[14.5954, 6.3312, 0.7671],
[67.5245, 22.4422, 1.5610],
[47.7693, -6.7980, 1.4395]]])
size_res = torch.tensor([[[-1.0752, 1.8760, 0.7715],
[-0.8016, 1.1754, 0.0102],
[-1.2789, 0.5948, 0.4728]]])
dir_class = torch.tensor([[[
0.1512, 1.7914, -1.7658, 2.1572, -0.9215, 1.2139, 0.1749, 0.8606,
1.1743, -0.7679, -1.6005, 0.4623
],
[
-0.3957, 1.2026, -1.2677, 1.3863, -0.5754,
1.7083, 0.2601, 0.1129, 0.7146, -0.1367,
-1.2892, -0.0083
],
[
-0.8862, 1.2050, -1.3881, 1.6604, -0.9087,
1.1907, -0.0280, 0.2027, 1.0644, -0.7205,
-1.0738, 0.4748
]]])
dir_res = torch.tensor([[[
1.1151, 0.5535, -0.2053, -0.6582, -0.1616, -0.1821, 0.4675, 0.6621,
0.8146, -0.0448, -0.7253, -0.7171
],
[
0.7888, 0.2478, -0.1962, -0.7267, 0.0573,
-0.2398, 0.6984, 0.5859, 0.7507, -0.1980,
-0.6538, -0.6602
],
[
0.9039, 0.6109, 0.1960, -0.5016, 0.0551,
-0.4086, 0.3398, 0.2759, 0.7247, -0.0655,
-0.5052, -0.9026
]]])
bbox_out = dict(
center=center, size=size_res, dir_class=dir_class, dir_res=dir_res)
bbox3d = box_coder.decode(bbox_out)
expected_bbox3d = torch.tensor(
[[[14.5954, 6.3312, 0.7671, 0.1000, 3.7521, 1.5429, 0.9126],
[67.5245, 22.4422, 1.5610, 0.1000, 2.3508, 0.1000, 2.3782],
[47.7693, -6.7980, 1.4395, 0.1000, 1.1897, 0.9456, 1.0692]]])
assert torch.allclose(bbox3d, expected_bbox3d, atol=1e-4)
# test split_pred
cls_preds = torch.rand(2, 1, 256)
reg_preds = torch.rand(2, 30, 256)
base_xyz = torch.rand(2, 256, 3)
results = box_coder.split_pred(cls_preds, reg_preds, base_xyz)
obj_scores = results['obj_scores']
center = results['center']
center_offset = results['center_offset']
dir_class = results['dir_class']
dir_res_norm = results['dir_res_norm']
dir_res = results['dir_res']
size = results['size']
assert obj_scores.shape == torch.Size([2, 1, 256])
assert center.shape == torch.Size([2, 256, 3])
assert center_offset.shape == torch.Size([2, 256, 3])
assert dir_class.shape == torch.Size([2, 256, 12])
assert dir_res_norm.shape == torch.Size([2, 256, 12])
assert dir_res.shape == torch.Size([2, 256, 12])
assert size.shape == torch.Size([2, 256, 3])