Unverified Commit 460f6b3b authored by encore-zhou's avatar encore-zhou Committed by GitHub
Browse files

[feature]: support ssd_3d_head in 3DSSD (#83)

* add ssd3dhead

* fix bugs for anchorfreebboxcoder

* modify ssd 3d head

* modify ssd 3d head

* reconstruct ssd3dhead and votehead

* add unittest

* modify 3dssd config

* modify 3dssd head

* modify 3dssd head

* rename base conv bbox head

* modify vote module

* modify 3dssd config

* fix bugs for unittest

* modify test_heads.py

* fix bugs for h3d bbox head

* add 3dssd detector

* fix bugs for 3dssd config

* modify base conv bbox head

* modify base conv bbox head

* modify base conv bbox head
parent 0e2ad8df
# 3DSSD on KITTI (Car class) — delta config on top of the _base_ files below.
_base_ = [
    '../_base_/models/3dssd.py', '../_base_/datasets/kitti-3d-car.py',
    '../_base_/default_runtime.py'
]
# dataset settings
dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
class_names = ['Car']
# [x_min, y_min, z_min, x_max, y_max, z_max] in the LiDAR frame (meters)
point_cloud_range = [0, -40, -5, 70, 40, 3]
input_modality = dict(use_lidar=True, use_camera=False)
# GT-sampling augmentation: paste database-cropped GT objects into scenes.
db_sampler = dict(
    data_root=data_root,
    info_path=data_root + 'kitti_dbinfos_train.pkl',
    rate=1.0,
    # keep all difficulties except -1; require at least 5 points per Car
    prepare=dict(filter_by_difficulty=[-1], filter_by_min_points=dict(Car=5)),
    classes=class_names,
    sample_groups=dict(Car=15))
file_client_args = dict(backend='disk')
# Uncomment the following if use ceph or other file clients.
# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
# for more details.
# file_client_args = dict(
#     backend='petrel', path_mapping=dict(data='s3://kitti_data/'))
train_pipeline = [
    dict(
        type='LoadPointsFromFile',
        load_dim=4,  # x, y, z, intensity
        use_dim=4,
        file_client_args=file_client_args),
    dict(
        type='LoadAnnotations3D',
        with_bbox_3d=True,
        with_label_3d=True,
        file_client_args=file_client_args),
    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectSample', db_sampler=db_sampler),
    dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
    # per-object jitter: translation / rotation noise on individual GT boxes
    dict(
        type='ObjectNoise',
        num_try=100,
        translation_std=[1.0, 1.0, 0],
        global_rot_range=[0.0, 0.0],
        rot_range=[-1.0471975511965976, 1.0471975511965976]),
    # global scene augmentation
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.78539816, 0.78539816],
        scale_ratio_range=[0.9, 1.1]),
    dict(type='BackgroundPointsFilter', bbox_enlarge_range=(0.5, 2.0, 0.5)),
    # NOTE(review): name says "Indoor" but it is used here as a generic
    # random point sub-sampler for the outdoor KITTI data.
    dict(type='IndoorPointSample', num_points=16384),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
    dict(
        type='LoadPointsFromFile',
        load_dim=4,
        use_dim=4,
        file_client_args=file_client_args),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            # identity transform at test time (no rot/scale/trans)
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1., 1.],
                translation_std=[0, 0, 0]),
            dict(type='RandomFlip3D'),
            dict(
                type='PointsRangeFilter', point_cloud_range=point_cloud_range),
            dict(type='IndoorPointSample', num_points=16384),
            dict(
                type='DefaultFormatBundle3D',
                class_names=class_names,
                with_label=False),
            dict(type='Collect3D', keys=['points'])
        ])
]
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    # the KITTI train set in _base_ is wrapped (RepeatDataset), hence the
    # nested dict(dataset=...) override for the pipeline
    train=dict(dataset=dict(pipeline=train_pipeline)),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
evaluation = dict(interval=2)
# model settings — override only the head for single-class (Car) detection
model = dict(
    bbox_head=dict(
        num_classes=1,
        bbox_coder=dict(
            type='AnchorFreeBBoxCoder', num_dir_bins=12, with_rot=True)))
# optimizer
lr = 0.002  # max learning rate
optimizer = dict(type='AdamW', lr=lr, weight_decay=0)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
lr_config = dict(policy='step', warmup=None, step=[80, 120])
# runtime settings
total_epochs = 150
# yapf:disable
log_config = dict(
    interval=30,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# yapf:enable
# Base model settings for 3DSSD (https://arxiv.org/abs/2002.10187).
model = dict(
    type='SSD3DNet',
    backbone=dict(
        type='PointNet2SAMSG',
        in_channels=4,  # x, y, z, intensity
        num_points=(4096, 512, (256, 256)),
        radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
        num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
        sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
                     ((64, 64, 128), (64, 64, 128), (64, 96, 128)),
                     ((128, 128, 256), (128, 192, 256), (128, 256, 256))),
        aggregation_channels=(64, 128, 256),
        # the last SA stage mixes feature-FPS and distance-FPS sampling
        fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
        fps_sample_range_lists=((-1), (-1), (512, -1)),
        norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.1),
        sa_cfg=dict(
            type='PointSAModuleMSG',
            pool_mod='max',
            use_xyz=True,
            normalize_xyz=False)),
    bbox_head=dict(
        type='SSD3DHead',
        in_channels=256,
        vote_module_cfg=dict(
            in_channels=256,
            num_points=256,
            gt_per_seed=1,
            conv_channels=(128, ),
            conv_cfg=dict(type='Conv1d'),
            norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.1),
            with_res_feat=False,
            vote_xyz_range=(3.0, 3.0, 2.0)),
        vote_aggregation_cfg=dict(
            type='PointSAModuleMSG',
            num_point=256,
            radii=(4.8, 6.4),
            sample_nums=(16, 32),
            mlp_channels=((256, 256, 256, 512), (256, 256, 512, 1024)),
            norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.1),
            use_xyz=True,
            normalize_xyz=False,
            bias=True),
        pred_layer_cfg=dict(
            in_channels=1536,
            shared_conv_channels=(512, 128),
            # NOTE: keys renamed from cls_conv_layers / reg_conv_layers to
            # match the BaseConvBboxHead constructor signature
            # (cls_conv_channels / reg_conv_channels); the old keys would
            # fall through **kwargs and crash nn.Module.__init__.
            cls_conv_channels=(128, ),
            reg_conv_channels=(128, ),
            conv_cfg=dict(type='Conv1d'),
            norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.1),
            bias=True),
        conv_cfg=dict(type='Conv1d'),
        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.1),
        objectness_loss=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        center_loss=dict(
            type='SmoothL1Loss', reduction='sum', loss_weight=1.0),
        dir_class_loss=dict(
            type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
        dir_res_loss=dict(
            type='SmoothL1Loss', reduction='sum', loss_weight=1.0),
        size_res_loss=dict(
            type='SmoothL1Loss', reduction='sum', loss_weight=1.0),
        corner_loss=dict(
            type='SmoothL1Loss', reduction='sum', loss_weight=1.0),
        vote_loss=dict(type='SmoothL1Loss', reduction='sum', loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
    sample_mod='spec', pos_distance_thr=10.0, expand_dims_length=0.05)
test_cfg = dict(
    nms_cfg=dict(type='nms', iou_thr=0.1),
    sample_mod='spec',
    score_thr=0.0,
    per_class_proposal=True,
    max_output_num=100)
# optimizer
# This schedule is mainly used by models on indoor dataset,
# e.g., VoteNet on SUNRGBD and ScanNet
lr = 0.002  # max learning rate
optimizer = dict(type='AdamW', lr=lr, weight_decay=0)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
lr_config = dict(policy='step', warmup=None, step=[80, 120])
# runtime settings
total_epochs = 150
......@@ -5,7 +5,7 @@ primitive_z_cfg = dict(
primitive_mode='z',
upper_thresh=100.0,
surface_thresh=0.5,
vote_moudule_cfg=dict(
vote_module_cfg=dict(
in_channels=256,
vote_per_seed=1,
gt_per_seed=1,
......@@ -63,7 +63,7 @@ primitive_xy_cfg = dict(
primitive_mode='xy',
upper_thresh=100.0,
surface_thresh=0.5,
vote_moudule_cfg=dict(
vote_module_cfg=dict(
in_channels=256,
vote_per_seed=1,
gt_per_seed=1,
......@@ -121,7 +121,7 @@ primitive_line_cfg = dict(
primitive_mode='line',
upper_thresh=100.0,
surface_thresh=0.5,
vote_moudule_cfg=dict(
vote_module_cfg=dict(
in_channels=256,
vote_per_seed=1,
gt_per_seed=1,
......@@ -198,7 +198,7 @@ model = dict(
normalize_xyz=True))),
rpn_head=dict(
type='VoteHead',
vote_moudule_cfg=dict(
vote_module_cfg=dict(
in_channels=256,
vote_per_seed=1,
gt_per_seed=3,
......@@ -219,7 +219,7 @@ model = dict(
mlp_channels=[256, 128, 128, 128],
use_xyz=True,
normalize_xyz=True),
feat_channels=(128, 128),
pred_layer_cfg=dict(in_channels=128, shared_conv_channels=(128, 128)),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
......
......@@ -17,7 +17,7 @@ model = dict(
normalize_xyz=True)),
bbox_head=dict(
type='VoteHead',
vote_moudule_cfg=dict(
vote_module_cfg=dict(
in_channels=256,
vote_per_seed=1,
gt_per_seed=3,
......@@ -38,7 +38,7 @@ model = dict(
mlp_channels=[256, 128, 128, 128],
use_xyz=True,
normalize_xyz=True),
feat_channels=(128, 128),
pred_layer_cfg=dict(in_channels=128, shared_conv_channels=(128, 128)),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
......
from mmdet.core.bbox import build_bbox_coder
from .anchor_free_bbox_coder import AnchorFreeBBoxCoder
from .delta_xyzwhlr_bbox_coder import DeltaXYZWLHRBBoxCoder
from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder
# Public API of the bbox-coder sub-package.
# (The stale pre-merge line without a trailing comma was removed: it was
# implicitly concatenated with the next string literal, injecting a bogus
# 'PartialBinBasedBBoxCoderbuild_bbox_coder' entry into __all__.)
__all__ = [
    'build_bbox_coder', 'DeltaXYZWLHRBBoxCoder', 'PartialBinBasedBBoxCoder',
    'AnchorFreeBBoxCoder'
]
import numpy as np
import torch
from mmdet.core.bbox.builder import BBOX_CODERS
from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder
@BBOX_CODERS.register_module()
class AnchorFreeBBoxCoder(PartialBinBasedBBoxCoder):
    """Anchor free bbox coder for 3D boxes.

    Sizes are regressed directly (as half-extents) instead of being
    classified against anchor templates, so the parent coder is configured
    with zero size classes.

    Args:
        num_dir_bins (int): Number of bins to encode direction angle.
        with_rot (bool): Whether the bbox is with rotation.
    """

    def __init__(self, num_dir_bins, with_rot=True):
        # num_sizes=0 and mean_sizes=[]: no size-class templates are used.
        super(AnchorFreeBBoxCoder, self).__init__(
            num_dir_bins, 0, [], with_rot=with_rot)
        self.num_dir_bins = num_dir_bins
        self.with_rot = with_rot

    def encode(self, gt_bboxes_3d, gt_labels_3d):
        """Encode ground truth to prediction targets.

        Args:
            gt_bboxes_3d (BaseInstance3DBoxes): Ground truth bboxes \
                with shape (n, 7).
            gt_labels_3d (torch.Tensor): Ground truth classes.

        Returns:
            tuple: Targets of center, size and direction.
        """
        # generate center target
        center_target = gt_bboxes_3d.gravity_center
        # generate bbox size target: half-extents (dims / 2)
        size_res_target = gt_bboxes_3d.dims / 2
        # generate dir target
        box_num = gt_labels_3d.shape[0]
        if self.with_rot:
            (dir_class_target,
             dir_res_target) = self.angle2class(gt_bboxes_3d.yaw)
            # normalize the residual by the bin width (2*pi / num_dir_bins)
            dir_res_target /= (2 * np.pi / self.num_dir_bins)
        else:
            # rotation disabled: zero class and residual targets
            dir_class_target = gt_labels_3d.new_zeros(box_num)
            dir_res_target = gt_bboxes_3d.tensor.new_zeros(box_num)
        return (center_target, size_res_target, dir_class_target,
                dir_res_target)

    def decode(self, bbox_out):
        """Decode predicted parts to bbox3d.

        Args:
            bbox_out (dict): Predictions from model, should contain keys below.

                - center: predicted bottom center of bboxes.
                - dir_class: predicted bbox direction class.
                - dir_res: predicted bbox direction residual.
                - size: predicted bbox size.

        Returns:
            torch.Tensor: Decoded bbox3d with shape (batch, n, 7).
        """
        center = bbox_out['center']
        batch_size, num_proposal = center.shape[:2]
        # decode heading angle
        if self.with_rot:
            # pick the residual of the most confident direction bin
            dir_class = torch.argmax(bbox_out['dir_class'], -1)
            dir_res = torch.gather(bbox_out['dir_res'], 2,
                                   dir_class.unsqueeze(-1))
            dir_res.squeeze_(2)
            dir_angle = self.class2angle(dir_class, dir_res).reshape(
                batch_size, num_proposal, 1)
        else:
            dir_angle = center.new_zeros(batch_size, num_proposal, 1)
        # decode bbox size: predictions are half-extents, clamp to avoid
        # degenerate (near-zero) boxes
        bbox_size = torch.clamp(bbox_out['size'] * 2, min=0.1)
        bbox3d = torch.cat([center, bbox_size, dir_angle], dim=-1)
        return bbox3d

    def split_pred(self, cls_preds, reg_preds, base_xyz):
        """Split predicted features to specific parts.

        Args:
            cls_preds (torch.Tensor): Class predicted features to split.
            reg_preds (torch.Tensor): Regression predicted features to split.
            base_xyz (torch.Tensor): Coordinates of points.

        Returns:
            dict[str, torch.Tensor]: Split results.
        """
        results = {}
        results['obj_scores'] = cls_preds
        start, end = 0, 0
        reg_preds_trans = reg_preds.transpose(2, 1)
        # decode center
        end += 3
        # (batch_size, num_proposal, 3)
        results['center_offset'] = reg_preds_trans[..., start:end]
        results['center'] = base_xyz.detach() + reg_preds_trans[..., start:end]
        start = end
        # decode size (was mislabelled "decode center" — fixed)
        end += 3
        # (batch_size, num_proposal, 3)
        results['size'] = reg_preds_trans[..., start:end]
        start = end
        # decode direction
        end += self.num_dir_bins
        results['dir_class'] = reg_preds_trans[..., start:end]
        start = end
        end += self.num_dir_bins
        dir_res_norm = reg_preds_trans[..., start:end]
        start = end
        results['dir_res_norm'] = dir_res_norm
        # un-normalize the residual back to radians
        results['dir_res'] = dir_res_norm * (2 * np.pi / self.num_dir_bins)
        return results
......@@ -98,11 +98,12 @@ class PartialBinBasedBBoxCoder(BaseBBoxCoder):
bbox3d = torch.cat([center, bbox_size, dir_angle], dim=-1)
return bbox3d
def split_pred(self, preds, base_xyz):
def split_pred(self, cls_preds, reg_preds, base_xyz):
"""Split predicted features to specific parts.
Args:
preds (torch.Tensor): Predicted features to split.
cls_preds (torch.Tensor): Class predicted features to split.
reg_preds (torch.Tensor): Regression predicted features to split.
base_xyz (torch.Tensor): Coordinates of points.
Returns:
......@@ -110,26 +111,24 @@ class PartialBinBasedBBoxCoder(BaseBBoxCoder):
"""
results = {}
start, end = 0, 0
preds_trans = preds.transpose(2, 1)
# decode objectness score
end += 2
results['obj_scores'] = preds_trans[..., start:end].contiguous()
start = end
cls_preds_trans = cls_preds.transpose(2, 1)
reg_preds_trans = reg_preds.transpose(2, 1)
# decode center
end += 3
# (batch_size, num_proposal, 3)
results['center'] = base_xyz + preds_trans[..., start:end].contiguous()
results['center'] = base_xyz + \
reg_preds_trans[..., start:end].contiguous()
start = end
# decode direction
end += self.num_dir_bins
results['dir_class'] = preds_trans[..., start:end].contiguous()
results['dir_class'] = reg_preds_trans[..., start:end].contiguous()
start = end
end += self.num_dir_bins
dir_res_norm = preds_trans[..., start:end].contiguous()
dir_res_norm = reg_preds_trans[..., start:end].contiguous()
start = end
results['dir_res_norm'] = dir_res_norm
......@@ -137,23 +136,29 @@ class PartialBinBasedBBoxCoder(BaseBBoxCoder):
# decode size
end += self.num_sizes
results['size_class'] = preds_trans[..., start:end].contiguous()
results['size_class'] = reg_preds_trans[..., start:end].contiguous()
start = end
end += self.num_sizes * 3
size_res_norm = preds_trans[..., start:end]
batch_size, num_proposal = preds_trans.shape[:2]
size_res_norm = reg_preds_trans[..., start:end]
batch_size, num_proposal = reg_preds_trans.shape[:2]
size_res_norm = size_res_norm.view(
[batch_size, num_proposal, self.num_sizes, 3])
start = end
results['size_res_norm'] = size_res_norm.contiguous()
mean_sizes = preds.new_tensor(self.mean_sizes)
mean_sizes = reg_preds.new_tensor(self.mean_sizes)
results['size_res'] = (
size_res_norm * mean_sizes.unsqueeze(0).unsqueeze(0))
# decode objectness score
start = 0
end = 2
results['obj_scores'] = cls_preds_trans[..., start:end].contiguous()
start = end
# decode semantic score
results['sem_scores'] = preds_trans[..., start:].contiguous()
results['sem_scores'] = cls_preds_trans[..., start:].contiguous()
return results
......
......@@ -252,6 +252,21 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
return box_idxs_of_pts.squeeze(0)
def enlarged_box(self, extra_width):
    """Enlarge the length, width and height of boxes.

    The gravity center stays fixed: each dimension grows by
    ``2 * extra_width`` while the bottom-center z drops by ``extra_width``.

    Args:
        extra_width (float | torch.Tensor): Extra width to enlarge the box.

    Returns:
        :obj:`DepthInstance3DBoxes`: Enlarged boxes.
    """
    enlarged_boxes = self.tensor.clone()
    enlarged_boxes[:, 3:6] += extra_width * 2
    # bottom center z minus extra_width so the gravity center is unchanged
    enlarged_boxes[:, 2] -= extra_width
    return self.new_box(enlarged_boxes)
def get_surface_line_center(self):
"""Compute surface and line center of bounding boxes.
......
......@@ -475,6 +475,7 @@ def eval_class(gt_annos,
if num_examples < num_parts:
num_parts = num_examples
split_parts = get_split_parts(num_examples, num_parts)
rets = calculate_iou_partly(dt_annos, gt_annos, metric, num_parts)
overlaps, parted_overlaps, total_dt_num, total_gt_num = rets
N_SAMPLE_PTS = 41
......
......@@ -24,6 +24,7 @@ class PointNet2SAMSG(BasePointNet):
fps_mods (tuple[int]): Mod of FPS for each SA module.
fps_sample_range_lists (tuple[tuple[int]]): The number of sampling
points which each SA module samples.
dilated_group (tuple[bool]): Whether to use dilated ball query for
out_indices (Sequence[int]): Output from which stages.
norm_cfg (dict): Config of normalization layer.
sa_cfg (dict): Config of set abstraction module, which may contain
......@@ -47,13 +48,14 @@ class PointNet2SAMSG(BasePointNet):
aggregation_channels=(64, 128, 256),
fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
fps_sample_range_lists=((-1), (-1), (512, -1)),
dilated_group=(True, True, True),
out_indices=(2, ),
norm_cfg=dict(type='BN2d'),
sa_cfg=dict(
type='PointSAModuleMSG',
pool_mod='max',
use_xyz=True,
normalize_xyz=True)):
normalize_xyz=False)):
super().__init__()
self.num_sa = len(sa_channels)
self.out_indices = out_indices
......@@ -94,6 +96,7 @@ class PointNet2SAMSG(BasePointNet):
mlp_channels=cur_sa_mlps,
fps_mod=cur_fps_mod,
fps_sample_range_list=cur_fps_sample_range_list,
dilated_group=dilated_group[sa_index],
norm_cfg=norm_cfg,
cfg=sa_cfg,
bias=True))
......@@ -137,6 +140,7 @@ class PointNet2SAMSG(BasePointNet):
out_sa_xyz = []
out_sa_features = []
out_sa_indices = []
for i in range(self.num_sa):
cur_xyz, cur_features, cur_indices = self.SA_modules[i](
sa_xyz[i], sa_features[i])
......
from .anchor3d_head import Anchor3DHead
from .base_conv_bbox_head import BaseConvBboxHead
from .free_anchor3d_head import FreeAnchor3DHead
from .parta2_rpn_head import PartA2RPNHead
from .ssd_3d_head import SSD3DHead
from .vote_head import VoteHead
# Public API of the dense-heads sub-package.
# (The stale pre-merge __all__ left by the diff view was removed: it was a
# dead assignment immediately overwritten by this one.)
__all__ = [
    'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead',
    'SSD3DHead', 'BaseConvBboxHead'
]
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_conv_layer
from torch import nn as nn
from mmdet.models.builder import HEADS
@HEADS.register_module()
class BaseConvBboxHead(nn.Module):
    r"""More general bbox head, with shared conv layers and two optional
    separated branches.

    .. code-block:: none

                     /-> cls convs -> cls_score
        shared convs
                     \-> reg convs -> bbox_pred

    Args:
        in_channels (int): Number of input feature channels.
        shared_conv_channels (tuple[int]): Out channels of the shared convs;
            empty means no shared stem.
        cls_conv_channels (tuple[int]): Out channels of the cls branch convs.
        num_cls_out_channels (int): Channels of the final cls prediction.
        reg_conv_channels (tuple[int]): Out channels of the reg branch convs.
        num_reg_out_channels (int): Channels of the final reg prediction.
        conv_cfg (dict): Config of the conv layers.
        norm_cfg (dict): Config of the norm layers.
        act_cfg (dict): Config of the activation layers.
        bias (bool | str): Bias config forwarded to ConvModule.
    """

    def __init__(self,
                 in_channels=0,
                 shared_conv_channels=(),
                 cls_conv_channels=(),
                 num_cls_out_channels=0,
                 reg_conv_channels=(),
                 num_reg_out_channels=0,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 act_cfg=dict(type='ReLU'),
                 bias='auto',
                 *args,
                 **kwargs):
        super(BaseConvBboxHead, self).__init__(*args, **kwargs)
        assert in_channels > 0
        assert num_cls_out_channels > 0
        assert num_reg_out_channels > 0
        self.in_channels = in_channels
        self.shared_conv_channels = shared_conv_channels
        self.cls_conv_channels = cls_conv_channels
        self.num_cls_out_channels = num_cls_out_channels
        self.reg_conv_channels = reg_conv_channels
        self.num_reg_out_channels = num_reg_out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.bias = bias

        # add shared convs
        if len(self.shared_conv_channels) > 0:
            self.shared_convs = self._add_conv_branch(
                self.in_channels, self.shared_conv_channels)
            out_channels = self.shared_conv_channels[-1]
        else:
            out_channels = self.in_channels

        # add cls specific branch
        prev_channel = out_channels
        if len(self.cls_conv_channels) > 0:
            self.cls_convs = self._add_conv_branch(prev_channel,
                                                   self.cls_conv_channels)
            prev_channel = self.cls_conv_channels[-1]
        # final 1x1 conv producing the raw class scores (no norm/act)
        self.conv_cls = build_conv_layer(
            conv_cfg,
            in_channels=prev_channel,
            out_channels=num_cls_out_channels,
            kernel_size=1)

        # add reg specific branch
        prev_channel = out_channels
        if len(self.reg_conv_channels) > 0:
            self.reg_convs = self._add_conv_branch(prev_channel,
                                                   self.reg_conv_channels)
            prev_channel = self.reg_conv_channels[-1]
        self.conv_reg = build_conv_layer(
            conv_cfg,
            in_channels=prev_channel,
            out_channels=num_reg_out_channels,
            kernel_size=1)

    def _add_conv_branch(self, in_channels, conv_channels):
        """Add shared or separable branch."""
        conv_spec = [in_channels] + list(conv_channels)
        # add branch specific conv layers
        conv_layers = nn.Sequential()
        for i in range(len(conv_spec) - 1):
            conv_layers.add_module(
                f'layer{i}',
                ConvModule(
                    conv_spec[i],
                    conv_spec[i + 1],
                    kernel_size=1,
                    padding=0,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=self.bias,
                    inplace=True))
        return conv_layers

    def init_weights(self):
        # conv layers are already initialized by ConvModule
        pass

    def forward(self, feats):
        """Forward.

        Args:
            feats (Tensor): Input features.

        Returns:
            Tensor: Class scores predictions.
            Tensor: Regression predictions.
        """
        # shared part
        if len(self.shared_conv_channels) > 0:
            x = self.shared_convs(feats)
        else:
            # BUGFIX: without shared convs, `x` was previously unbound here
            # and `x_cls = x` raised NameError; feed the raw features to
            # both branches instead.
            x = feats

        # separate branches
        x_cls = x
        x_reg = x

        if len(self.cls_conv_channels) > 0:
            x_cls = self.cls_convs(x_cls)
        cls_score = self.conv_cls(x_cls)

        if len(self.reg_conv_channels) > 0:
            x_reg = self.reg_convs(x_reg)
        bbox_pred = self.conv_reg(x_reg)

        return cls_score, bbox_pred
import torch
from mmcv.ops.nms import batched_nms
from torch.nn import functional as F
from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
LiDARInstance3DBoxes,
rotation_3d_in_axis)
from mmdet3d.models.builder import build_loss
from mmdet.core import multi_apply
from mmdet.models import HEADS
from .vote_head import VoteHead
@HEADS.register_module()
class SSD3DHead(VoteHead):
r"""Bbox head of `3DSSD <https://arxiv.org/abs/2002.10187>`_.
Args:
num_classes (int): The number of class.
bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
decoding boxes.
in_channels (int): The number of input feature channel.
train_cfg (dict): Config for training.
test_cfg (dict): Config for testing.
vote_module_cfg (dict): Config of VoteModule for point-wise votes.
vote_aggregation_cfg (dict): Config of vote aggregation layer.
pred_layer_cfg (dict): Config of classfication and regression
prediction layers.
conv_cfg (dict): Config of convolution in prediction layer.
norm_cfg (dict): Config of BN in prediction layer.
act_cfg (dict): Config of activation in prediction layer.
objectness_loss (dict): Config of objectness loss.
center_loss (dict): Config of center loss.
dir_class_loss (dict): Config of direction classification loss.
dir_res_loss (dict): Config of direction residual regression loss.
size_res_loss (dict): Config of size residual regression loss.
corner_loss (dict): Config of bbox corners regression loss.
vote_loss (dict): Config of candidate points regression loss.
"""
def __init__(self,
             num_classes,
             bbox_coder,
             in_channels=256,
             train_cfg=None,
             test_cfg=None,
             vote_module_cfg=None,
             vote_aggregation_cfg=None,
             pred_layer_cfg=None,
             conv_cfg=dict(type='Conv1d'),
             norm_cfg=dict(type='BN1d'),
             act_cfg=dict(type='ReLU'),
             objectness_loss=None,
             center_loss=None,
             dir_class_loss=None,
             dir_res_loss=None,
             size_res_loss=None,
             corner_loss=None,
             vote_loss=None):
    """Initialize SSD3DHead (see the class docstring for argument docs)."""
    # Delegate shared wiring to VoteHead; size_class_loss and semantic_loss
    # are passed as None — 3DSSD regresses sizes directly and has no
    # separate semantic branch.
    super(SSD3DHead, self).__init__(
        num_classes,
        bbox_coder,
        train_cfg=train_cfg,
        test_cfg=test_cfg,
        vote_module_cfg=vote_module_cfg,
        vote_aggregation_cfg=vote_aggregation_cfg,
        pred_layer_cfg=pred_layer_cfg,
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg,
        objectness_loss=objectness_loss,
        center_loss=center_loss,
        dir_class_loss=dir_class_loss,
        dir_res_loss=dir_res_loss,
        size_class_loss=None,
        size_res_loss=size_res_loss,
        semantic_loss=None)
    # losses specific to 3DSSD (not present in VoteHead)
    self.corner_loss = build_loss(corner_loss)
    self.vote_loss = build_loss(vote_loss)
    # number of candidate points produced by the vote module
    self.num_candidates = vote_module_cfg['num_points']
def _get_cls_out_channels(self):
    """Return the channel number of classification outputs."""
    # One score per class; unlike VoteHead there is NO extra objectness
    # channel here (the old "+ objectness (1)" comment was stale).
    return self.num_classes
def _get_reg_out_channels(self):
    """Return the channel number of regression outputs."""
    # center offset (3) + size regression (3)
    # + heading class and residual (num_dir_bins * 2)
    return 3 + 3 + self.num_dir_bins * 2
def _extract_input(self, feat_dict):
"""Extract inputs from features dictionary.
Args:
feat_dict (dict): Feature dict from backbone.
Returns:
torch.Tensor: Coordinates of input points.
torch.Tensor: Features of input points.
torch.Tensor: Indices of input points.
"""
seed_points = feat_dict['sa_xyz'][-1]
seed_features = feat_dict['sa_features'][-1]
seed_indices = feat_dict['sa_indices'][-1]
return seed_points, seed_features, seed_indices
def loss(self,
         bbox_preds,
         points,
         gt_bboxes_3d,
         gt_labels_3d,
         pts_semantic_mask=None,
         pts_instance_mask=None,
         img_metas=None,
         gt_bboxes_ignore=None):
    """Compute loss.

    Args:
        bbox_preds (dict): Predictions from forward of SSD3DHead.
        points (list[torch.Tensor]): Input points.
        gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
            bboxes of each sample.
        gt_labels_3d (list[torch.Tensor]): Labels of each sample.
        pts_semantic_mask (None | list[torch.Tensor]): Point-wise
            semantic mask.
        pts_instance_mask (None | list[torch.Tensor]): Point-wise
            instance mask.
        img_metas (list[dict]): Contain pcd and img's meta info.
        gt_bboxes_ignore (None | list[torch.Tensor]): Specify
            which bounding.

    Returns:
        dict: Losses of 3DSSD.
    """
    targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
                               pts_semantic_mask, pts_instance_mask,
                               bbox_preds)
    (vote_targets, center_targets, size_res_targets, dir_class_targets,
     dir_res_targets, mask_targets, centerness_targets, corner3d_targets,
     vote_mask, positive_mask, negative_mask, centerness_weights,
     box_loss_weights, heading_res_loss_weight) = targets

    # calculate centerness loss (objectness_loss is applied to per-class
    # centerness targets — 3DSSD has no separate objectness channel)
    centerness_loss = self.objectness_loss(
        bbox_preds['obj_scores'].transpose(2, 1),
        centerness_targets,
        weight=centerness_weights)

    # calculate center loss (offsets relative to the aggregated points)
    center_loss = self.center_loss(
        bbox_preds['center_offset'],
        center_targets,
        weight=box_loss_weights.unsqueeze(-1))

    # calculate direction class loss
    dir_class_loss = self.dir_class_loss(
        bbox_preds['dir_class'].transpose(1, 2),
        dir_class_targets,
        weight=box_loss_weights)

    # calculate direction residual loss; the residual target is broadcast
    # over all bins and masked to the GT bin via heading_res_loss_weight
    dir_res_loss = self.dir_res_loss(
        bbox_preds['dir_res_norm'],
        dir_res_targets.unsqueeze(-1).repeat(1, 1, self.num_dir_bins),
        weight=heading_res_loss_weight)

    # calculate size residual loss
    size_loss = self.size_res_loss(
        bbox_preds['size'],
        size_res_targets,
        weight=box_loss_weights.unsqueeze(-1))

    # calculate corner loss: decode boxes with the GT direction class
    # (one-hot) so only the residual of the correct bin is penalized
    one_hot_dir_class_targets = dir_class_targets.new_zeros(
        bbox_preds['dir_class'].shape)
    one_hot_dir_class_targets.scatter_(2, dir_class_targets.unsqueeze(-1),
                                       1)
    pred_bbox3d = self.bbox_coder.decode(
        dict(
            center=bbox_preds['center'],
            dir_res=bbox_preds['dir_res'],
            dir_class=one_hot_dir_class_targets,
            size=bbox_preds['size']))
    pred_bbox3d = pred_bbox3d.reshape(-1, pred_bbox3d.shape[-1])
    pred_bbox3d = img_metas[0]['box_type_3d'](
        pred_bbox3d.clone(),
        box_dim=pred_bbox3d.shape[-1],
        with_yaw=self.bbox_coder.with_rot,
        origin=(0.5, 0.5, 0.5))
    pred_corners3d = pred_bbox3d.corners.reshape(-1, 8, 3)
    corner_loss = self.corner_loss(
        pred_corners3d,
        corner3d_targets.reshape(-1, 8, 3),
        weight=box_loss_weights.view(-1, 1, 1))

    # calculate vote loss
    vote_loss = self.vote_loss(
        bbox_preds['vote_offset'].transpose(1, 2),
        vote_targets,
        weight=vote_mask.unsqueeze(-1))

    losses = dict(
        centerness_loss=centerness_loss,
        center_loss=center_loss,
        dir_class_loss=dir_class_loss,
        dir_res_loss=dir_res_loss,
        size_res_loss=size_loss,
        corner_loss=corner_loss,
        vote_loss=vote_loss)
    return losses
def get_targets(self,
                points,
                gt_bboxes_3d,
                gt_labels_3d,
                pts_semantic_mask=None,
                pts_instance_mask=None,
                bbox_preds=None):
    """Generate targets of ssd3d head.

    Args:
        points (list[torch.Tensor]): Points of each batch.
        gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
            bboxes of each batch.
        gt_labels_3d (list[torch.Tensor]): Labels of each batch.
        pts_semantic_mask (None | list[torch.Tensor]): Point-wise semantic
            label of each batch.
        pts_instance_mask (None | list[torch.Tensor]): Point-wise instance
            label of each batch.
        bbox_preds (torch.Tensor): Bounding box predictions of ssd3d head.

    Returns:
        tuple[torch.Tensor]: Targets of ssd3d head.
    """
    # find empty example: substitute a single all-zero fake box so that
    # per-sample target generation never sees an empty GT set
    for index in range(len(gt_labels_3d)):
        if len(gt_labels_3d[index]) == 0:
            fake_box = gt_bboxes_3d[index].tensor.new_zeros(
                1, gt_bboxes_3d[index].tensor.shape[-1])
            gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
            gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)

    if pts_semantic_mask is None:
        pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
        pts_instance_mask = [None for i in range(len(gt_labels_3d))]

    aggregated_points = [
        bbox_preds['aggregated_points'][i]
        for i in range(len(gt_labels_3d))
    ]

    # only the first num_candidates seed points vote; gradients are cut
    seed_points = [
        bbox_preds['seed_points'][i, :self.num_candidates].detach()
        for i in range(len(gt_labels_3d))
    ]

    # per-sample target generation, then stack back into batch tensors
    (vote_targets, center_targets, size_res_targets, dir_class_targets,
     dir_res_targets, mask_targets, centerness_targets, corner3d_targets,
     vote_mask, positive_mask, negative_mask) = multi_apply(
         self.get_targets_single, points, gt_bboxes_3d, gt_labels_3d,
         pts_semantic_mask, pts_instance_mask, aggregated_points,
         seed_points)

    center_targets = torch.stack(center_targets)
    positive_mask = torch.stack(positive_mask)
    negative_mask = torch.stack(negative_mask)
    dir_class_targets = torch.stack(dir_class_targets)
    dir_res_targets = torch.stack(dir_res_targets)
    size_res_targets = torch.stack(size_res_targets)
    mask_targets = torch.stack(mask_targets)
    centerness_targets = torch.stack(centerness_targets).detach()
    corner3d_targets = torch.stack(corner3d_targets)
    vote_targets = torch.stack(vote_targets)
    vote_mask = torch.stack(vote_mask)

    # center targets become offsets relative to the aggregated points
    center_targets -= bbox_preds['aggregated_points']

    # normalize the loss weights so each loss sums to ~1 over the batch
    centerness_weights = (positive_mask +
                          negative_mask).unsqueeze(-1).repeat(
                              1, 1, self.num_classes).float()
    centerness_weights = centerness_weights / \
        (centerness_weights.sum() + 1e-6)
    vote_mask = vote_mask / (vote_mask.sum() + 1e-6)
    box_loss_weights = positive_mask / (positive_mask.sum() + 1e-6)

    # one-hot heading mask: only the GT direction bin's residual is weighted
    batch_size, proposal_num = dir_class_targets.shape[:2]
    heading_label_one_hot = dir_class_targets.new_zeros(
        (batch_size, proposal_num, self.num_dir_bins))
    heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
    heading_res_loss_weight = heading_label_one_hot * \
        box_loss_weights.unsqueeze(-1)

    return (vote_targets, center_targets, size_res_targets,
            dir_class_targets, dir_res_targets, mask_targets,
            centerness_targets, corner3d_targets, vote_mask, positive_mask,
            negative_mask, centerness_weights, box_loss_weights,
            heading_res_loss_weight)
def get_targets_single(self,
                       points,
                       gt_bboxes_3d,
                       gt_labels_3d,
                       pts_semantic_mask=None,
                       pts_instance_mask=None,
                       aggregated_points=None,
                       seed_points=None):
    """Generate targets of ssd3d head for single batch.

    Args:
        points (torch.Tensor): Points of each batch.
        gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth \
            boxes of each batch.
        gt_labels_3d (torch.Tensor): Labels of each batch.
        pts_semantic_mask (None | torch.Tensor): Point-wise semantic
            label of each batch.
        pts_instance_mask (None | torch.Tensor): Point-wise instance
            label of each batch.
        aggregated_points (torch.Tensor): Aggregated points from
            candidate points layer.
        seed_points (torch.Tensor): Seed points of candidate points.

    Returns:
        tuple[torch.Tensor]: Targets of ssd3d head.
    """
    assert self.bbox_coder.with_rot or pts_semantic_mask is not None
    gt_bboxes_3d = gt_bboxes_3d.to(points.device)

    # drop ignored ground truths (label == -1)
    valid_gt = gt_labels_3d != -1
    gt_bboxes_3d = gt_bboxes_3d[valid_gt]
    gt_labels_3d = gt_labels_3d[valid_gt]
    gt_corner3d = gt_bboxes_3d.corners

    (center_targets, size_targets, dir_class_targets,
     dir_res_targets) = self.bbox_coder.encode(gt_bboxes_3d, gt_labels_3d)

    # assign each aggregated point to the GT box containing it
    points_mask, assignment = self._assign_targets_by_points_inside(
        gt_bboxes_3d, aggregated_points)

    center_targets = center_targets[assignment]
    size_res_targets = size_targets[assignment]
    mask_targets = gt_labels_3d[assignment]
    dir_class_targets = dir_class_targets[assignment]
    dir_res_targets = dir_res_targets[assignment]
    corner3d_targets = gt_corner3d[assignment]

    # positives: inside a GT box AND within pos_distance_thr of its
    # top-face center; negatives: inside no GT box at all
    top_center_targets = center_targets.clone()
    top_center_targets[:, 2] += size_res_targets[:, 2]
    dist = torch.norm(aggregated_points - top_center_targets, dim=1)
    dist_mask = dist < self.train_cfg.pos_distance_thr
    positive_mask = (points_mask.max(1)[0] > 0) * dist_mask
    negative_mask = (points_mask.max(1)[0] == 0)

    # Centerness loss targets: express each point in the assigned box's
    # canonical (yaw-aligned, center-origin) frame
    canonical_xyz = aggregated_points - center_targets
    if self.bbox_coder.with_rot:
        # TODO: Align points rotation implementation of
        # LiDARInstance3DBoxes and DepthInstance3DBoxes
        canonical_xyz = rotation_3d_in_axis(
            canonical_xyz.unsqueeze(0).transpose(0, 1),
            -gt_bboxes_3d.yaw[assignment], 2).squeeze(1)

    # distances from the point to the six box faces (half-extents +/- coord)
    distance_front = torch.clamp(
        size_res_targets[:, 0] - canonical_xyz[:, 0], min=0)
    distance_back = torch.clamp(
        size_res_targets[:, 0] + canonical_xyz[:, 0], min=0)
    distance_left = torch.clamp(
        size_res_targets[:, 1] - canonical_xyz[:, 1], min=0)
    distance_right = torch.clamp(
        size_res_targets[:, 1] + canonical_xyz[:, 1], min=0)
    distance_top = torch.clamp(
        size_res_targets[:, 2] - canonical_xyz[:, 2], min=0)
    distance_bottom = torch.clamp(
        size_res_targets[:, 2] + canonical_xyz[:, 2], min=0)

    # FCOS-style centerness per axis, combined by geometric mean (cube root)
    centerness_l = torch.min(distance_front, distance_back) / torch.max(
        distance_front, distance_back)
    centerness_w = torch.min(distance_left, distance_right) / torch.max(
        distance_left, distance_right)
    centerness_h = torch.min(distance_bottom, distance_top) / torch.max(
        distance_bottom, distance_top)
    centerness_targets = torch.clamp(
        centerness_l * centerness_w * centerness_h, min=0)
    centerness_targets = centerness_targets.pow(1 / 3.0)
    centerness_targets = torch.clamp(centerness_targets, min=0, max=1)

    # scatter the scalar centerness into the assigned class channel
    proposal_num = centerness_targets.shape[0]
    one_hot_centerness_targets = centerness_targets.new_zeros(
        (proposal_num, self.num_classes))
    one_hot_centerness_targets.scatter_(1, mask_targets.unsqueeze(-1), 1)
    centerness_targets = centerness_targets.unsqueeze(
        1) * one_hot_centerness_targets

    # Vote loss targets: seeds inside slightly enlarged GT boxes vote
    # towards the gravity center of their assigned box
    enlarged_gt_bboxes_3d = gt_bboxes_3d.enlarged_box(
        self.train_cfg.expand_dims_length)
    # also extend the enlarged box downwards by expand_dims_length
    enlarged_gt_bboxes_3d.tensor[:, 2] -= self.train_cfg.expand_dims_length
    vote_mask, vote_assignment = self._assign_targets_by_points_inside(
        enlarged_gt_bboxes_3d, seed_points)
    vote_targets = gt_bboxes_3d.gravity_center
    vote_targets = vote_targets[vote_assignment] - seed_points
    vote_mask = vote_mask.max(1)[0] > 0

    return (vote_targets, center_targets, size_res_targets,
            dir_class_targets, dir_res_targets, mask_targets,
            centerness_targets, corner3d_targets, vote_mask, positive_mask,
            negative_mask)
def get_bboxes(self, points, bbox_preds, input_metas, rescale=False):
    """Generate bboxes from ssd3d head predictions.

    Args:
        points (torch.Tensor): Input points.
        bbox_preds (dict): Predictions from ssd3d head.
        input_metas (list[dict]): Point cloud and image's meta info.
        rescale (bool): Whether to rescale bboxes.

    Returns:
        list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
    """
    # decode boxes
    # `torch.sigmoid` replaces the deprecated `F.sigmoid`
    # (torch.nn.functional.sigmoid); the result is numerically identical.
    sem_scores = torch.sigmoid(bbox_preds['obj_scores']).transpose(1, 2)
    # per-proposal objectness = best semantic class score
    obj_scores = sem_scores.max(-1)[0]
    bbox3d = self.bbox_coder.decode(bbox_preds)

    batch_size = bbox3d.shape[0]
    results = list()
    for b in range(batch_size):
        # NMS is done per sample; only xyz of the input points is needed.
        bbox_selected, score_selected, labels = self.multiclass_nms_single(
            obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3],
            input_metas[b])
        # re-wrap the selected raw tensors in the sample's box type
        bbox = input_metas[b]['box_type_3d'](
            bbox_selected.clone(),
            box_dim=bbox_selected.shape[-1],
            with_yaw=self.bbox_coder.with_rot)
        results.append((bbox, score_selected, labels))

    return results
def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
                          input_meta):
    """Multi-class nms in single batch.

    Args:
        obj_scores (torch.Tensor): Objectness score of bounding boxes.
        sem_scores (torch.Tensor): semantic class score of bounding boxes.
        bbox (torch.Tensor): Predicted bounding boxes.
        points (torch.Tensor): Input points.
        input_meta (dict): Point cloud and image's meta info.

    Returns:
        tuple[torch.Tensor]: Bounding boxes, scores and labels.
    """
    num_bbox = bbox.shape[0]
    # Wrap the raw box tensor into the sample's box structure; origin
    # (0.5, 0.5, 1.0) matches the coder's decoded-box convention.
    bbox = input_meta['box_type_3d'](
        bbox.clone(),
        box_dim=bbox.shape[-1],
        with_yaw=self.bbox_coder.with_rot,
        origin=(0.5, 0.5, 1.0))

    if isinstance(bbox, LiDARInstance3DBoxes):
        # points_in_boxes returns one box index per point (-1 = outside
        # every box); histogram the indices to count points per box.
        box_idx = bbox.points_in_boxes(points)
        box_indices = box_idx.new_zeros([num_bbox + 1])
        # funnel "outside" points into an extra bucket, dropped below
        box_idx[box_idx == -1] = num_bbox
        box_indices.scatter_add_(0, box_idx.long(),
                                 box_idx.new_ones(box_idx.shape))
        box_indices = box_indices[:-1]
        # NOTE(review): counts are non-negative by construction, so
        # `>= 0` keeps every box; `> 0` would drop empty ones — confirm
        # whether keeping empty boxes here is intentional.
        nonempty_box_mask = box_indices >= 0
    elif isinstance(bbox, DepthInstance3DBoxes):
        box_indices = bbox.points_in_boxes(points)
        # NOTE(review): same here — the summed indicator is >= 0 by
        # construction, so no box is ever filtered out.
        nonempty_box_mask = box_indices.T.sum(1) >= 0
    else:
        raise NotImplementedError('Unsupported bbox type!')

    # axis-aligned (xmin, ymin, zmin, xmax, ymax, zmax) hull of each
    # (possibly rotated) box, used as the NMS proposal
    corner3d = bbox.corners
    minmax_box3d = corner3d.new(torch.Size((corner3d.shape[0], 6)))
    minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0]
    minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0]

    bbox_classes = torch.argmax(sem_scores, -1)
    # class-aware NMS on the BEV (x, y) projection only
    nms_selected = batched_nms(
        minmax_box3d[nonempty_box_mask][:, [0, 1, 3, 4]],
        obj_scores[nonempty_box_mask], bbox_classes[nonempty_box_mask],
        self.test_cfg.nms_cfg)[1]

    # cap the number of surviving proposals
    if nms_selected.shape[0] > self.test_cfg.max_output_num:
        nms_selected = nms_selected[:self.test_cfg.max_output_num]

    # filter empty boxes and boxes with low score
    scores_mask = (obj_scores >= self.test_cfg.score_thr)
    # map the NMS-kept indices (over the masked subset) back to a mask
    # over the original proposal set
    nonempty_box_inds = torch.nonzero(nonempty_box_mask).flatten()
    nonempty_mask = torch.zeros_like(bbox_classes).scatter(
        0, nonempty_box_inds[nms_selected], 1)
    selected = (nonempty_mask.bool() & scores_mask.bool())

    if self.test_cfg.per_class_proposal:
        # duplicate every kept box once per class so downstream
        # per-class evaluation can score all classes
        bbox_selected, score_selected, labels = [], [], []
        for k in range(sem_scores.shape[-1]):
            bbox_selected.append(bbox[selected].tensor)
            score_selected.append(obj_scores[selected])
            labels.append(
                torch.zeros_like(bbox_classes[selected]).fill_(k))
        bbox_selected = torch.cat(bbox_selected, 0)
        score_selected = torch.cat(score_selected, 0)
        labels = torch.cat(labels, 0)
    else:
        bbox_selected = bbox[selected].tensor
        score_selected = obj_scores[selected]
        labels = bbox_classes[selected]

    return bbox_selected, score_selected, labels
def _assign_targets_by_points_inside(self, bboxes_3d, points):
    """Compute assignment by checking whether point is inside bbox.

    Args:
        bboxes_3d (BaseInstance3DBoxes): Instance of bounding boxes.
        points (torch.Tensor): Points of a batch.

    Returns:
        tuple[torch.Tensor]: Flags indicating whether each point is
            inside bbox and the index of box where each point are in.
    """
    # TODO: align points_in_boxes function in each box_structures
    n_boxes = bboxes_3d.tensor.shape[0]
    if isinstance(bboxes_3d, DepthInstance3DBoxes):
        # depth boxes already yield a per-point, per-box indicator matrix
        inside_mask = bboxes_3d.points_in_boxes(points)
        box_assignment = inside_mask.argmax(dim=-1)
    elif isinstance(bboxes_3d, LiDARInstance3DBoxes):
        # LiDAR boxes yield one box index per point (-1 = outside all)
        box_assignment = bboxes_3d.points_in_boxes(points).long()
        n_points = box_assignment.shape[0]
        inside_mask = box_assignment.new_zeros([n_points, n_boxes + 1])
        # route "outside every box" points into a scratch column,
        # one-hot scatter, then drop that scratch column again
        box_assignment[box_assignment == -1] = n_boxes
        inside_mask.scatter_(1, box_assignment.unsqueeze(1), 1)
        inside_mask = inside_mask[:, :-1]
        # clamp outside points to a valid box index for later gathers
        box_assignment[box_assignment == n_boxes] = n_boxes - 1
    else:
        raise NotImplementedError('Unsupported bbox type!')
    return inside_mask, box_assignment
import numpy as np
import torch
from mmcv.cnn import ConvModule
from torch import nn as nn
from torch.nn import functional as F
......@@ -11,6 +10,7 @@ from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import build_sa_module, furthest_point_sample
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
from .base_conv_bbox_head import BaseConvBboxHead
@HEADS.register_module()
......@@ -23,10 +23,10 @@ class VoteHead(nn.Module):
decoding boxes.
train_cfg (dict): Config for training.
test_cfg (dict): Config for testing.
vote_moudule_cfg (dict): Config of VoteModule for point-wise votes.
vote_module_cfg (dict): Config of VoteModule for point-wise votes.
vote_aggregation_cfg (dict): Config of vote aggregation layer.
feat_channels (tuple[int]): Convolution channels of
prediction layer.
pred_layer_cfg (dict): Config of classfication and regression
prediction layers.
conv_cfg (dict): Config of convolution in prediction layer.
norm_cfg (dict): Config of BN in prediction layer.
objectness_loss (dict): Config of objectness loss.
......@@ -43,9 +43,9 @@ class VoteHead(nn.Module):
bbox_coder,
train_cfg=None,
test_cfg=None,
vote_moudule_cfg=None,
vote_module_cfg=None,
vote_aggregation_cfg=None,
feat_channels=(128, 128),
pred_layer_cfg=None,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=None,
......@@ -59,54 +59,64 @@ class VoteHead(nn.Module):
self.num_classes = num_classes
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.gt_per_seed = vote_moudule_cfg['gt_per_seed']
self.gt_per_seed = vote_module_cfg['gt_per_seed']
self.num_proposal = vote_aggregation_cfg['num_point']
self.objectness_loss = build_loss(objectness_loss)
self.center_loss = build_loss(center_loss)
self.dir_class_loss = build_loss(dir_class_loss)
self.dir_res_loss = build_loss(dir_res_loss)
self.size_class_loss = build_loss(size_class_loss)
self.dir_class_loss = build_loss(dir_class_loss)
self.size_res_loss = build_loss(size_res_loss)
self.semantic_loss = build_loss(semantic_loss)
assert vote_aggregation_cfg['mlp_channels'][0] == vote_moudule_cfg[
'in_channels']
if size_class_loss is not None:
self.size_class_loss = build_loss(size_class_loss)
if semantic_loss is not None:
self.semantic_loss = build_loss(semantic_loss)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_sizes = self.bbox_coder.num_sizes
self.num_dir_bins = self.bbox_coder.num_dir_bins
self.vote_module = VoteModule(**vote_moudule_cfg)
self.vote_module = VoteModule(**vote_module_cfg)
self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
conv_pred_list = list()
for k in range(len(feat_channels)):
conv_pred_list.append(
ConvModule(
prev_channel,
feat_channels[k],
1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
inplace=True))
prev_channel = feat_channels[k]
self.conv_pred = nn.Sequential(*conv_pred_list)
# Bbox classification and regression
self.conv_pred = BaseConvBboxHead(
**pred_layer_cfg,
num_cls_out_channels=self._get_cls_out_channels(),
num_reg_out_channels=self._get_reg_out_channels())
def init_weights(self):
"""Initialize weights of VoteHead."""
pass
def _get_cls_out_channels(self):
"""Return the channel number of classification outputs."""
# Class numbers (k) + objectness (2)
return self.num_classes + 2
def _get_reg_out_channels(self):
"""Return the channel number of regression outputs."""
# Objectness scores (2), center residual (3),
# heading class+residual (num_dir_bins*2),
# size class+residual(num_sizes*4)
conv_out_channel = (2 + 3 + self.num_dir_bins * 2 +
self.num_sizes * 4 + num_classes)
self.conv_pred.add_module('conv_out',
nn.Conv1d(prev_channel, conv_out_channel, 1))
return 3 + self.num_dir_bins * 2 + self.num_sizes * 4
def init_weights(self):
"""Initialize weights of VoteHead."""
pass
def _extract_input(self, feat_dict):
"""Extract inputs from features dictionary.
Args:
feat_dict (dict): Feature dict from backbone.
Returns:
torch.Tensor: Coordinates of input points.
torch.Tensor: Features of input points.
torch.Tensor: Indices of input points.
"""
seed_points = feat_dict['fp_xyz'][-1]
seed_features = feat_dict['fp_features'][-1]
seed_indices = feat_dict['fp_indices'][-1]
return seed_points, seed_features, seed_indices
def forward(self, feat_dict, sample_mod):
"""Forward pass.
......@@ -122,57 +132,74 @@ class VoteHead(nn.Module):
Args:
feat_dict (dict): Feature dict from backbone.
sample_mod (str): Sample mode for vote aggregation layer.
valid modes are "vote", "seed" and "random".
valid modes are "vote", "seed", "random" and "spec".
Returns:
dict: Predictions of vote head.
"""
assert sample_mod in ['vote', 'seed', 'random']
assert sample_mod in ['vote', 'seed', 'random', 'spec']
seed_points = feat_dict['fp_xyz'][-1]
seed_features = feat_dict['fp_features'][-1]
seed_indices = feat_dict['fp_indices'][-1]
seed_points, seed_features, seed_indices = self._extract_input(
feat_dict)
# 1. generate vote_points from seed_points
vote_points, vote_features = self.vote_module(seed_points,
seed_features)
vote_points, vote_features, vote_offset = self.vote_module(
seed_points, seed_features)
results = dict(
seed_points=seed_points,
seed_indices=seed_indices,
vote_points=vote_points,
vote_features=vote_features)
vote_features=vote_features,
vote_offset=vote_offset)
# 2. aggregate vote_points
if sample_mod == 'vote':
# use fps in vote_aggregation
sample_indices = None
aggregation_inputs = dict(
points_xyz=vote_points, features=vote_features)
elif sample_mod == 'seed':
# FPS on seed and choose the votes corresponding to the seeds
sample_indices = furthest_point_sample(seed_points,
self.num_proposal)
aggregation_inputs = dict(
points_xyz=vote_points,
features=vote_features,
indices=sample_indices)
elif sample_mod == 'random':
# Random sampling from the votes
batch_size, num_seed = seed_points.shape[:2]
sample_indices = seed_points.new_tensor(
torch.randint(0, num_seed, (batch_size, self.num_proposal)),
dtype=torch.int32)
aggregation_inputs = dict(
points_xyz=vote_points,
features=vote_features,
indices=sample_indices)
elif sample_mod == 'spec':
# Specify the new center in vote_aggregation
aggregation_inputs = dict(
points_xyz=seed_points,
features=seed_features,
target_xyz=vote_points)
else:
raise NotImplementedError(
f'Sample mode {sample_mod} is not supported!')
vote_aggregation_ret = self.vote_aggregation(vote_points,
vote_features,
sample_indices)
vote_aggregation_ret = self.vote_aggregation(**aggregation_inputs)
aggregated_points, features, aggregated_indices = vote_aggregation_ret
results['aggregated_points'] = aggregated_points
results['aggregated_features'] = features
results['aggregated_indices'] = aggregated_indices
# 3. predict bbox and score
predictions = self.conv_pred(features)
cls_predictions, reg_predictions = self.conv_pred(features)
# 4. decode predictions
decode_res = self.bbox_coder.split_pred(predictions, aggregated_points)
decode_res = self.bbox_coder.split_pred(cls_predictions,
reg_predictions,
aggregated_points)
results.update(decode_res)
return results
......
......@@ -4,10 +4,12 @@ from .h3dnet import H3DNet
from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN
from .mvx_two_stage import MVXTwoStageDetector
from .parta2 import PartA2
from .ssd3dnet import SSD3DNet
from .votenet import VoteNet
from .voxelnet import VoxelNet
__all__ = [
'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector',
'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet'
'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet',
'SSD3DNet'
]
from mmdet.models import DETECTORS
from .votenet import VoteNet
@DETECTORS.register_module()
class SSD3DNet(VoteNet):
    """3DSSDNet model.

    https://arxiv.org/abs/2002.10187.pdf

    3DSSD reuses the single-stage VoteNet detector pipeline, so all
    construction is delegated to the parent class unchanged.
    """

    def __init__(self,
                 backbone,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # Nothing 3DSSD-specific to set up here — the config files select
        # the 3DSSD backbone/head; simply forward every argument.
        super().__init__(
            backbone=backbone,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)
import torch
from mmcv import is_tuple_of
from mmcv.cnn import ConvModule
from torch import nn as nn
......@@ -15,6 +16,7 @@ class VoteModule(nn.Module):
vote_per_seed (int): Number of votes generated from each seed point.
gt_per_seed (int): Number of ground truth votes generated
from each seed point.
num_points (int): Number of points to be used for voting.
conv_channels (tuple[int]): Out channels of vote
generating convolution.
conv_cfg (dict): Config of convolution.
......@@ -23,6 +25,9 @@ class VoteModule(nn.Module):
Default: dict(type='BN1d').
norm_feats (bool): Whether to normalize features.
Default: True.
with_res_feat (bool): Whether to predict residual features.
Default: True.
vote_xyz_range (list[float], optional): Per-axis limit for the predicted
    vote offset; offsets are clamped into this range. Default: None.
vote_loss (dict): Config of vote loss.
"""
......@@ -30,17 +35,28 @@ class VoteModule(nn.Module):
in_channels,
vote_per_seed=1,
gt_per_seed=3,
num_points=-1,
conv_channels=(16, 16),
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU'),
norm_feats=True,
with_res_feat=True,
vote_xyz_range=None,
vote_loss=None):
super().__init__()
self.in_channels = in_channels
self.vote_per_seed = vote_per_seed
self.gt_per_seed = gt_per_seed
self.num_points = num_points
self.norm_feats = norm_feats
self.vote_loss = build_loss(vote_loss)
self.with_res_feat = with_res_feat
assert vote_xyz_range is None or is_tuple_of(vote_xyz_range, float)
self.vote_xyz_range = vote_xyz_range
if vote_loss is not None:
self.vote_loss = build_loss(vote_loss)
prev_channels = in_channels
vote_conv_list = list()
......@@ -53,13 +69,17 @@ class VoteModule(nn.Module):
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
bias=True,
inplace=True))
prev_channels = conv_channels[k]
self.vote_conv = nn.Sequential(*vote_conv_list)
# conv_out predicts coordinate and residual features
out_channel = (3 + in_channels) * self.vote_per_seed
if with_res_feat:
out_channel = (3 + in_channels) * self.vote_per_seed
else:
out_channel = 3 * self.vote_per_seed
self.conv_out = nn.Conv1d(prev_channels, out_channel, 1)
def forward(self, seed_points, seed_feats):
......@@ -80,6 +100,13 @@ class VoteModule(nn.Module):
shape (B, C, M) where ``M=num_seed*vote_per_seed``, \
``C=vote_feature_dim``.
"""
if self.num_points != -1:
assert self.num_points < seed_points.shape[1], \
f'Number of vote points ({self.num_points}) should be '\
f'smaller than seed points size ({seed_points.shape[1]})'
seed_points = seed_points[:, :self.num_points]
seed_feats = seed_feats[..., :self.num_points]
batch_size, feat_channels, num_seed = seed_feats.shape
num_vote = num_seed * self.vote_per_seed
x = self.vote_conv(seed_feats)
......@@ -88,21 +115,36 @@ class VoteModule(nn.Module):
votes = votes.transpose(2, 1).view(batch_size, num_seed,
self.vote_per_seed, -1)
offset = votes[:, :, :, 0:3].contiguous()
res_feats = votes[:, :, :, 3:].contiguous()
vote_points = seed_points.unsqueeze(2) + offset
offset = votes[:, :, :, 0:3]
if self.vote_xyz_range is not None:
limited_offset_list = []
for axis in range(len(self.vote_xyz_range)):
limited_offset_list.append(offset[..., axis].clamp(
min=-self.vote_xyz_range[axis],
max=self.vote_xyz_range[axis]))
limited_offset = torch.stack(limited_offset_list, -1)
vote_points = (seed_points.unsqueeze(2) +
limited_offset).contiguous()
else:
vote_points = (seed_points.unsqueeze(2) + offset).contiguous()
vote_points = vote_points.view(batch_size, num_vote, 3)
vote_feats = seed_feats.permute(
0, 2, 1).unsqueeze(2).contiguous() + res_feats
vote_feats = vote_feats.view(batch_size, num_vote,
feat_channels).transpose(2,
1).contiguous()
if self.norm_feats:
features_norm = torch.norm(vote_feats, p=2, dim=1)
vote_feats = vote_feats.div(features_norm.unsqueeze(1))
return vote_points, vote_feats
offset = offset.reshape(batch_size, num_vote, 3).transpose(2, 1)
if self.with_res_feat:
res_feats = votes[:, :, :, 3:]
vote_feats = (seed_feats.transpose(2, 1).unsqueeze(2) +
res_feats).contiguous()
vote_feats = vote_feats.view(batch_size,
num_vote, feat_channels).transpose(
2, 1).contiguous()
if self.norm_feats:
features_norm = torch.norm(vote_feats, p=2, dim=1)
vote_feats = vote_feats.div(features_norm.unsqueeze(1))
else:
vote_feats = seed_feats
return vote_points, vote_feats, offset
def get_loss(self, seed_points, vote_points, seed_indices,
vote_targets_mask, vote_targets):
......
......@@ -308,8 +308,9 @@ class H3DBboxHead(nn.Module):
for conv_module in self.bbox_pred[1:]:
bbox_predictions = conv_module(bbox_predictions)
refine_decode_res = self.bbox_coder.split_pred(bbox_predictions,
aggregated_points)
refine_decode_res = self.bbox_coder.split_pred(
bbox_predictions[:, :self.num_classes + 2],
bbox_predictions[:, self.num_classes + 2:], aggregated_points)
for key in refine_decode_res.keys():
ret_dict[key + '_optimized'] = refine_decode_res[key]
return ret_dict
......
......@@ -23,7 +23,7 @@ class PrimitiveHead(nn.Module):
decoding boxes.
train_cfg (dict): Config for training.
test_cfg (dict): Config for testing.
vote_moudule_cfg (dict): Config of VoteModule for point-wise votes.
vote_module_cfg (dict): Config of VoteModule for point-wise votes.
vote_aggregation_cfg (dict): Config of vote aggregation layer.
feat_channels (tuple[int]): Convolution channels of
prediction layer.
......@@ -42,7 +42,7 @@ class PrimitiveHead(nn.Module):
primitive_mode,
train_cfg=None,
test_cfg=None,
vote_moudule_cfg=None,
vote_module_cfg=None,
vote_aggregation_cfg=None,
feat_channels=(128, 128),
upper_thresh=100.0,
......@@ -61,7 +61,7 @@ class PrimitiveHead(nn.Module):
self.primitive_mode = primitive_mode
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.gt_per_seed = vote_moudule_cfg['gt_per_seed']
self.gt_per_seed = vote_module_cfg['gt_per_seed']
self.num_proposal = vote_aggregation_cfg['num_point']
self.upper_thresh = upper_thresh
self.surface_thresh = surface_thresh
......@@ -71,13 +71,13 @@ class PrimitiveHead(nn.Module):
self.semantic_reg_loss = build_loss(semantic_reg_loss)
self.semantic_cls_loss = build_loss(semantic_cls_loss)
assert vote_aggregation_cfg['mlp_channels'][0] == vote_moudule_cfg[
assert vote_aggregation_cfg['mlp_channels'][0] == vote_module_cfg[
'in_channels']
# Primitive existence flag prediction
self.flag_conv = ConvModule(
vote_moudule_cfg['conv_channels'][-1],
vote_moudule_cfg['conv_channels'][-1] // 2,
vote_module_cfg['conv_channels'][-1],
vote_module_cfg['conv_channels'][-1] // 2,
1,
padding=0,
conv_cfg=conv_cfg,
......@@ -85,9 +85,9 @@ class PrimitiveHead(nn.Module):
bias=True,
inplace=True)
self.flag_pred = torch.nn.Conv1d(
vote_moudule_cfg['conv_channels'][-1] // 2, 2, 1)
vote_module_cfg['conv_channels'][-1] // 2, 2, 1)
self.vote_module = VoteModule(**vote_moudule_cfg)
self.vote_module = VoteModule(**vote_module_cfg)
self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
......@@ -137,8 +137,8 @@ class PrimitiveHead(nn.Module):
results['pred_flag_' + self.primitive_mode] = primitive_flag
# 1. generate vote_points from seed_points
vote_points, vote_features = self.vote_module(seed_points,
seed_features)
vote_points, vote_features, _ = self.vote_module(
seed_points, seed_features)
results['vote_' + self.primitive_mode] = vote_points
results['vote_features_' + self.primitive_mode] = vote_features
......
import torch
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.core.bbox import DepthInstance3DBoxes, LiDARInstance3DBoxes
from mmdet.core import build_bbox_coder
......@@ -194,9 +194,10 @@ def test_partial_bin_based_box_coder():
assert torch.allclose(bbox3d, expected_bbox3d, atol=1e-4)
# test split_pred
box_preds = torch.rand(2, 79, 256)
cls_preds = torch.rand(2, 12, 256)
reg_preds = torch.rand(2, 67, 256)
base_xyz = torch.rand(2, 256, 3)
results = box_coder.split_pred(box_preds, base_xyz)
results = box_coder.split_pred(cls_preds, reg_preds, base_xyz)
obj_scores = results['obj_scores']
center = results['center']
dir_class = results['dir_class']
......@@ -215,3 +216,110 @@ def test_partial_bin_based_box_coder():
assert size_res_norm.shape == torch.Size([2, 256, 10, 3])
assert size_res.shape == torch.Size([2, 256, 10, 3])
assert sem_scores.shape == torch.Size([2, 256, 10])
def test_anchor_free_box_coder():
    """Unit test for ``AnchorFreeBBoxCoder`` (used by the 3DSSD head).

    Covers ``encode`` (gt boxes -> regression targets), ``decode``
    (predictions -> boxes) and ``split_pred`` (raw conv outputs -> named
    prediction dict), checking both values and output shapes.
    """
    box_coder_cfg = dict(
        type='AnchorFreeBBoxCoder', num_dir_bins=12, with_rot=True)
    box_coder = build_bbox_coder(box_coder_cfg)

    # test encode
    # NOTE(review): the expected tensors below appear to be golden values
    # captured from a reference run — regenerate them if the coder changes.
    gt_bboxes = LiDARInstance3DBoxes(
        [[2.1227e+00, 5.7951e+00, -9.9900e-01, 1.6736e+00, 4.2419e+00,
          1.5473e+00, -1.5501e+00],
         [1.1791e+01, 9.0276e+00, -8.5772e-01, 1.6210e+00, 3.5367e+00,
          1.4841e+00, -1.7369e+00],
         [2.3638e+01, 9.6997e+00, -5.6713e-01, 1.7578e+00, 4.6103e+00,
          1.5999e+00, -1.4556e+00]])
    gt_labels = torch.tensor([0, 0, 0])
    (center_targets, size_targets, dir_class_targets,
     dir_res_targets) = box_coder.encode(gt_bboxes, gt_labels)
    expected_center_target = torch.tensor([[2.1227, 5.7951, -0.2253],
                                           [11.7908, 9.0276, -0.1156],
                                           [23.6380, 9.6997, 0.2328]])
    expected_size_targets = torch.tensor([[0.8368, 2.1210, 0.7736],
                                          [0.8105, 1.7683, 0.7421],
                                          [0.8789, 2.3052, 0.8000]])
    expected_dir_class_target = torch.tensor([9, 9, 9])
    expected_dir_res_target = torch.tensor([0.0394, -0.3172, 0.2199])
    assert torch.allclose(center_targets, expected_center_target, atol=1e-4)
    assert torch.allclose(size_targets, expected_size_targets, atol=1e-4)
    assert torch.all(dir_class_targets == expected_dir_class_target)
    assert torch.allclose(dir_res_targets, expected_dir_res_target, atol=1e-3)

    # test decode
    center = torch.tensor([[[14.5954, 6.3312, 0.7671],
                            [67.5245, 22.4422, 1.5610],
                            [47.7693, -6.7980, 1.4395]]])
    size_res = torch.tensor([[[-1.0752, 1.8760, 0.7715],
                              [-0.8016, 1.1754, 0.0102],
                              [-1.2789, 0.5948, 0.4728]]])
    # one logit per direction bin (num_dir_bins=12) for each proposal
    dir_class = torch.tensor([[[0.1512, 1.7914, -1.7658, 2.1572, -0.9215,
                                1.2139, 0.1749, 0.8606, 1.1743, -0.7679,
                                -1.6005, 0.4623],
                               [-0.3957, 1.2026, -1.2677, 1.3863, -0.5754,
                                1.7083, 0.2601, 0.1129, 0.7146, -0.1367,
                                -1.2892, -0.0083],
                               [-0.8862, 1.2050, -1.3881, 1.6604, -0.9087,
                                1.1907, -0.0280, 0.2027, 1.0644, -0.7205,
                                -1.0738, 0.4748]]])
    dir_res = torch.tensor([[[1.1151, 0.5535, -0.2053, -0.6582, -0.1616,
                              -0.1821, 0.4675, 0.6621, 0.8146, -0.0448,
                              -0.7253, -0.7171],
                             [0.7888, 0.2478, -0.1962, -0.7267, 0.0573,
                              -0.2398, 0.6984, 0.5859, 0.7507, -0.1980,
                              -0.6538, -0.6602],
                             [0.9039, 0.6109, 0.1960, -0.5016, 0.0551,
                              -0.4086, 0.3398, 0.2759, 0.7247, -0.0655,
                              -0.5052, -0.9026]]])
    bbox_out = dict(
        center=center, size=size_res, dir_class=dir_class, dir_res=dir_res)
    bbox3d = box_coder.decode(bbox_out)
    expected_bbox3d = torch.tensor(
        [[[14.5954, 6.3312, 0.7671, 0.1000, 3.7521, 1.5429, 0.9126],
          [67.5245, 22.4422, 1.5610, 0.1000, 2.3508, 0.1000, 2.3782],
          [47.7693, -6.7980, 1.4395, 0.1000, 1.1897, 0.9456, 1.0692]]])
    assert torch.allclose(bbox3d, expected_bbox3d, atol=1e-4)

    # test split_pred
    # cls_preds: (batch, num_classes, proposals); reg_preds carries
    # center offset + dir class/res + size for each proposal
    cls_preds = torch.rand(2, 1, 256)
    reg_preds = torch.rand(2, 30, 256)
    base_xyz = torch.rand(2, 256, 3)
    results = box_coder.split_pred(cls_preds, reg_preds, base_xyz)
    obj_scores = results['obj_scores']
    center = results['center']
    center_offset = results['center_offset']
    dir_class = results['dir_class']
    dir_res_norm = results['dir_res_norm']
    dir_res = results['dir_res']
    size = results['size']
    assert obj_scores.shape == torch.Size([2, 1, 256])
    assert center.shape == torch.Size([2, 256, 3])
    assert center_offset.shape == torch.Size([2, 256, 3])
    assert dir_class.shape == torch.Size([2, 256, 12])
    assert dir_res_norm.shape == torch.Size([2, 256, 12])
    assert dir_res.shape == torch.Size([2, 256, 12])
    assert size.shape == torch.Size([2, 256, 3])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment