Unverified Commit d07bd8bd authored by yinchimaoliang, committed by GitHub

[feature]: Add multi-group head of CenterPoint (#49)

* Add modules.

* Add test_center_head.

* Add docstring.

* Change comments.

* Add dcn_head.

* Add doc_string.

* Add get_targets.

* Can use get_targets.

* get_targets results aligned.

* Use box_structure.

* Use get_targets_single.

* Add docstring.

* Fix dcn_center_head unittest.

* Delete unnecessary unittest.

* Add docstring.

* Change format.

* Add circle_nms.

* Change structure of mg_head.

* Add bbox coder for centerpoint.

* Add docstrings.

* Add docstrings.

* Add get_bboxes and unittest.

* Change docstring.

* Add img_metas.

* Change bbox coder unittest.

* Add task_detections.

* Change docstring.

* Change circle_nms to cpu.

* Change test_nms.

* Change score_th, change keep to long type.

* Change docstring and unittest.

* Remove unnecessary things.

* Move gaussian.

* move clip_sigmoid, change dict.

* Change config.

* Change test_heads.

* Move weight initialization to init_weights func.

* Remove loc_loss_element and num_positive.

* Change bboxes to the right format.

* Change loss and bbox order.

* Update test_heads.

* Change loss.

* Change names in mg_head, change head unittest.

* Remove centerpoint_focal_loss, change docstring.

* Change topK default to 80.

* Change boxes in test_nms. Change task_boxes defaults to None.

* Fix rotate nms bug.

* Change docstring.

* Add docstring for get_task_detection and loss.

* Remove gaussian funcs, change mg_head.

* Change gaussianfocalloss to mean.

* change centerpoint_bbox_coder '/' to torch.div, fix centerhead unittest.

* Change div to '/'

* Change order in centerpoint_coder, change names, change dcn layer.

* Fix import in __init__

* Add gaussian unittest.

* Remove np ops in mg_head.

* Update docstring.

* Fix docstring use config to build head.

* Remove **kwargs

* Remove unnecessary codes, change order of bboxes.

* Remove '\' in args and pdb, change loss_bbox.

* Fix test_heads unittest.

* Remove unnecessary comments

* Change bbox order in rotate nms.

* Remove unnecessary attributes

* Change name, remove float
parent 79a8299c
...@@ -2,5 +2,6 @@ from .anchor import * # noqa: F401, F403
from .bbox import * # noqa: F401, F403
from .evaluation import * # noqa: F401, F403
from .post_processing import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
from .visualizer import * # noqa: F401, F403
from .voxel import * # noqa: F401, F403
from mmdet.core.bbox import build_bbox_coder
from .anchor_free_bbox_coder import AnchorFreeBBoxCoder
from .centerpoint_bbox_coders import CenterPointBBoxCoder
from .delta_xyzwhlr_bbox_coder import DeltaXYZWLHRBBoxCoder
from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder
__all__ = [
'build_bbox_coder', 'DeltaXYZWLHRBBoxCoder', 'PartialBinBasedBBoxCoder',
'CenterPointBBoxCoder', 'AnchorFreeBBoxCoder'
]
import torch
from mmdet.core.bbox import BaseBBoxCoder
from mmdet.core.bbox.builder import BBOX_CODERS
@BBOX_CODERS.register_module()
class CenterPointBBoxCoder(BaseBBoxCoder):
"""Bbox coder for CenterPoint.
Args:
pc_range (list[float]): Range of point cloud.
out_size_factor (int): Downsample factor of the model.
voxel_size (list[float]): Size of voxel.
post_center_range (list[float]): Limit of the center.
Default: None.
max_num (int): Max number to be kept. Default: 100.
score_threshold (float): Threshold to filter boxes based on score.
Default: None.
code_size (int): Code size of bboxes. Default: 9.
"""
def __init__(self,
pc_range,
out_size_factor,
voxel_size,
post_center_range=None,
max_num=100,
score_threshold=None,
code_size=9):
self.pc_range = pc_range
self.out_size_factor = out_size_factor
self.voxel_size = voxel_size
self.post_center_range = post_center_range
self.max_num = max_num
self.score_threshold = score_threshold
self.code_size = code_size
def _gather_feat(self, feats, inds, feat_masks=None):
"""Given feats and indexes, returns the gathered feats.
Args:
feats (torch.Tensor): Feature map to be gathered with the shape
of [B, N, C].
inds (torch.Tensor): Indexes with the shape of [B, N].
feat_masks (torch.Tensor): Mask of the feats. Default: None.
Returns:
torch.Tensor: Gathered feats.
"""
dim = feats.size(2)
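# expand inds from [B, N] to [B, N, C] so gather selects whole
# C-dim feature vectors per index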
inds = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), dim)
feats = feats.gather(1, inds)
if feat_masks is not None:
feat_masks = feat_masks.unsqueeze(2).expand_as(feats)
feats = feats[feat_masks]
feats = feats.view(-1, dim)
return feats
def _topk(self, scores, K=80):
"""Get indexes based on scores.
Args:
scores (torch.Tensor): Scores with the shape of [B, N, W, H].
K (int): Number to be kept. Defaults to 80.
Returns:
tuple[torch.Tensor]
torch.Tensor: Selected scores with the shape of [B, K].
torch.Tensor: Selected indexes with the shape of [B, K].
torch.Tensor: Selected classes with the shape of [B, K].
torch.Tensor: Selected y coord with the shape of [B, K].
torch.Tensor: Selected x coord with the shape of [B, K].
"""
batch, cat, height, width = scores.size()
topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
topk_inds = topk_inds % (height * width)
topk_ys = (topk_inds.float() /
torch.tensor(width, dtype=torch.float)).int().float()
topk_xs = (topk_inds % width).int().float()
topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
topk_clses = (topk_ind / torch.tensor(K, dtype=torch.float)).int()
topk_inds = self._gather_feat(topk_inds.view(batch, -1, 1),
topk_ind).view(batch, K)
topk_ys = self._gather_feat(topk_ys.view(batch, -1, 1),
topk_ind).view(batch, K)
topk_xs = self._gather_feat(topk_xs.view(batch, -1, 1),
topk_ind).view(batch, K)
return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
def _transpose_and_gather_feat(self, feat, ind):
"""Given feats and indexes, returns the transposed and gathered feats.
Args:
feat (torch.Tensor): Feature map to be transposed and gathered
with the shape of [B, C, H, W].
ind (torch.Tensor): Indexes with the shape of [B, N].
Returns:
torch.Tensor: Transposed and gathered feats.
"""
feat = feat.permute(0, 2, 3, 1).contiguous()
feat = feat.view(feat.size(0), -1, feat.size(3))
feat = self._gather_feat(feat, ind)
return feat
def encode(self):
pass
def decode(self,
heat,
rot_sine,
rot_cosine,
hei,
dim,
vel,
reg=None,
task_id=-1):
"""Decode bboxes.
Args:
heat (torch.Tensor): Heatmap with the shape of [B, N, W, H].
rot_sine (torch.Tensor): Sine of rotation with the shape of
[B, 1, W, H].
rot_cosine (torch.Tensor): Cosine of rotation with the shape of
[B, 1, W, H].
hei (torch.Tensor): Height of the boxes with the shape
of [B, 1, W, H].
dim (torch.Tensor): Dim of the boxes with the shape of
[B, 1, W, H].
vel (torch.Tensor): Velocity with the shape of [B, 1, W, H].
reg (torch.Tensor): Regression value of the boxes in 2D with
the shape of [B, 2, W, H]. Default: None.
task_id (int): Index of task. Default: -1.
Returns:
list[dict]: Decoded boxes.
"""
batch, cat, _, _ = heat.size()
scores, inds, clses, ys, xs = self._topk(heat, K=self.max_num)
if reg is not None:
reg = self._transpose_and_gather_feat(reg, inds)
reg = reg.view(batch, self.max_num, 2)
xs = xs.view(batch, self.max_num, 1) + reg[:, :, 0:1]
ys = ys.view(batch, self.max_num, 1) + reg[:, :, 1:2]
else:
xs = xs.view(batch, self.max_num, 1) + 0.5
ys = ys.view(batch, self.max_num, 1) + 0.5
# rotation value and direction label
rot_sine = self._transpose_and_gather_feat(rot_sine, inds)
rot_sine = rot_sine.view(batch, self.max_num, 1)
rot_cosine = self._transpose_and_gather_feat(rot_cosine, inds)
rot_cosine = rot_cosine.view(batch, self.max_num, 1)
rot = torch.atan2(rot_sine, rot_cosine)
# height in the bev
hei = self._transpose_and_gather_feat(hei, inds)
hei = hei.view(batch, self.max_num, 1)
# dim of the box
dim = self._transpose_and_gather_feat(dim, inds)
dim = dim.view(batch, self.max_num, 3)
# class label
clses = clses.view(batch, self.max_num).float()
scores = scores.view(batch, self.max_num)
xs = xs.view(
batch, self.max_num,
1) * self.out_size_factor * self.voxel_size[0] + self.pc_range[0]
ys = ys.view(
batch, self.max_num,
1) * self.out_size_factor * self.voxel_size[1] + self.pc_range[1]
if vel is None:  # KITTI format, no velocity
final_box_preds = torch.cat([xs, ys, hei, dim, rot], dim=2)
else:  # velocity exists, nuScenes format
vel = self._transpose_and_gather_feat(vel, inds)
vel = vel.view(batch, self.max_num, 2)
final_box_preds = torch.cat([xs, ys, hei, dim, rot, vel], dim=2)
final_scores = scores
final_preds = clses
# use score threshold
if self.score_threshold is not None:
thresh_mask = final_scores > self.score_threshold
if self.post_center_range is not None:
self.post_center_range = torch.tensor(
self.post_center_range, device=heat.device)
mask = (final_box_preds[..., :3] >=
self.post_center_range[:3]).all(2)
mask &= (final_box_preds[..., :3] <=
self.post_center_range[3:]).all(2)
predictions_dicts = []
for i in range(batch):
cmask = mask[i, :]
if self.score_threshold:
cmask &= thresh_mask[i]
boxes3d = final_box_preds[i, cmask]
scores = final_scores[i, cmask]
labels = final_preds[i, cmask]
predictions_dict = {
'bboxes': boxes3d,
'scores': scores,
'labels': labels
}
predictions_dicts.append(predictions_dict)
else:
raise NotImplementedError(
'Need to reorganize output as a batch, only '
'support post_center_range is not None for now!')
return predictions_dicts
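For reference, a minimal usage sketch of the coder above (illustrative, not part of this diff); the shapes and config values mirror the unit test added later in this commit:

import torch
from mmdet.core import build_bbox_coder

coder = build_bbox_coder(
    dict(
        type='CenterPointBBoxCoder',
        post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
        max_num=100,
        score_threshold=0.1,
        pc_range=[-51.2, -51.2],
        out_size_factor=4,
        voxel_size=[0.2, 0.2]))
heat = torch.rand(1, 2, 128, 128)        # per-class center heatmap
rot_sine = torch.rand(1, 1, 128, 128)    # sin of yaw
rot_cosine = torch.rand(1, 1, 128, 128)  # cos of yaw
hei = torch.rand(1, 1, 128, 128)         # z of the gravity center
dim = torch.rand(1, 3, 128, 128)         # box size (log scale when norm_bbox)
vel = torch.rand(1, 2, 128, 128)         # velocity
reg = torch.rand(1, 2, 128, 128)         # sub-voxel center offset
out = coder.decode(heat, rot_sine, rot_cosine, hei, dim, vel, reg=reg)
# one dict per sample: 'bboxes' [<=max_num, 9], 'scores', 'labels'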
from mmdet.core.post_processing import (merge_aug_bboxes, merge_aug_masks,
merge_aug_proposals, merge_aug_scores,
multiclass_nms)
from .box3d_nms import aligned_3d_nms, box3d_multiclass_nms, circle_nms
from .merge_augs import merge_aug_bboxes_3d
__all__ = [
'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
'merge_aug_scores', 'merge_aug_masks', 'box3d_multiclass_nms',
'aligned_3d_nms', 'merge_aug_bboxes_3d', 'circle_nms'
]
import numba
import numpy as np
import torch
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
...@@ -134,3 +136,46 @@ def aligned_3d_nms(boxes, scores, classes, thresh):
indices = boxes.new_tensor(pick, dtype=torch.long)
return indices
@numba.jit(nopython=True)
def circle_nms(dets, thresh, post_max_size=83):
"""Circular NMS.
An object is only counted as positive if no other center
with a higher confidence exists within a radius r using a
bird-eye view distance metric.
Args:
dets (torch.Tensor): Detection results with the shape of [N, 3].
thresh (float): Value of threshold.
post_max_size (int): Max number of prediction to be kept. Defaults
to 83
Returns:
torch.Tensor: Indexes of the detections to be kept.
"""
x1 = dets[:, 0]
y1 = dets[:, 1]
scores = dets[:, 2]
order = scores.argsort()[::-1].astype(np.int32) # highest->lowest
ndets = dets.shape[0]
suppressed = np.zeros((ndets), dtype=np.int32)
keep = []
for _i in range(ndets):
i = order[_i] # start with highest score box
if suppressed[i] == 1:
# already suppressed by a higher-scoring box, skip it
continue
keep.append(i)
for _j in range(_i + 1, ndets):
j = order[_j]
if suppressed[j] == 1:
continue
# calculate center distance between i and j box
dist = (x1[i] - x1[j])**2 + (y1[i] - y1[j])**2
# ovr = inter / areas[j]
if dist <= thresh:
suppressed[j] = 1
return keep[:post_max_size]
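A quick sketch of circle_nms (illustrative): it consumes a NumPy array of [x, y, score] rows, and thresh is compared against the squared BEV center distance:

import numpy as np
dets = np.array([[0.0, 0.0, 0.9],
                 [0.1, 0.1, 0.8],  # squared distance 0.02 <= thresh
                 [5.0, 5.0, 0.7]])
keep = circle_nms(dets, thresh=0.175)
# keep == [0, 2]; the middle detection is suppressed by the first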
from .gaussian import draw_heatmap_gaussian, gaussian_2d, gaussian_radius
__all__ = ['gaussian_2d', 'gaussian_radius', 'draw_heatmap_gaussian']
import numpy as np
import torch
def gaussian_2d(shape, sigma=1):
"""Generate gaussian map.
Args:
shape (list[int]): Shape of the map.
sigma (float): Sigma to generate gaussian map.
Defaults to 1.
Returns:
np.ndarray: Generated gaussian map.
"""
m, n = [(ss - 1.) / 2. for ss in shape]
y, x = np.ogrid[-m:m + 1, -n:n + 1]
h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
return h
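For intuition, a tiny check of gaussian_2d (illustrative, using the function defined above): the map peaks at exactly 1.0 in its center cell.

h = gaussian_2d((5, 5), sigma=1)
assert h.shape == (5, 5) and h[2, 2] == 1.0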
def draw_heatmap_gaussian(heatmap, center, radius, k=1):
"""Get gaussian masked heatmap.
Args:
heatmap (torch.Tensor): Heatmap to be masked.
center (torch.Tensor): Center coord of the heatmap.
radius (int): Radius of gaussian.
k (int): Multiplier of masked_gaussian. Defaults to 1.
Returns:
torch.Tensor: Masked heatmap.
"""
diameter = 2 * radius + 1
gaussian = gaussian_2d((diameter, diameter), sigma=diameter / 6)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = torch.from_numpy(
gaussian[radius - top:radius + bottom,
radius - left:radius + right]).to(heatmap.device,
torch.float32)
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
torch.max(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
return heatmap
def gaussian_radius(det_size, min_overlap=0.5):
"""Get radius of gaussian.
Args:
det_size (tuple[torch.Tensor]): Size of the detection result.
min_overlap (float): Gaussian_overlap. Defaults to 0.5.
Returns:
torch.Tensor: Computed radius.
"""
height, width = det_size
a1 = 1
b1 = (height + width)
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
sq1 = torch.sqrt(b1**2 - 4 * a1 * c1)
r1 = (b1 + sq1) / 2
a2 = 4
b2 = 2 * (height + width)
c2 = (1 - min_overlap) * width * height
sq2 = torch.sqrt(b2**2 - 4 * a2 * c2)
r2 = (b2 + sq2) / 2
a3 = 4 * min_overlap
b3 = -2 * min_overlap * (height + width)
c3 = (min_overlap - 1) * width * height
sq3 = torch.sqrt(b3**2 - 4 * a3 * c3)
r3 = (b3 + sq3) / 2
return min(r1, r2, r3)
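Putting the two helpers together, a hedged sketch of building one heatmap target (the values are illustrative; CenterHead below does this per ground-truth box):

import torch
heatmap = torch.zeros((128, 128))  # one class channel of the BEV heatmap
# box size already scaled to feature-map cells
length, width = torch.tensor(4.2), torch.tensor(1.9)
radius = gaussian_radius((length, width), min_overlap=0.1)
radius = max(2, int(radius))  # clamp like train_cfg['min_radius']
center = torch.tensor([40, 55], dtype=torch.int32)  # cell coordinates
draw_heatmap_gaussian(heatmap, center, radius)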
...@@ -3,10 +3,11 @@ from mmdet.models.roi_heads.bbox_heads import (BBoxHead, ConvFCBBoxHead,
Shared2FCBBoxHead,
Shared4Conv1FCBBoxHead)
from .h3d_bbox_head import H3DBboxHead
from .multi_group_head import CenterHead
from .parta2_bbox_head import PartA2BboxHead
__all__ = [
'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead',
'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'PartA2BboxHead',
'H3DBboxHead', 'CenterHead'
]
import copy
import numpy as np
import torch
from mmcv.cnn import ConvModule, build_conv_layer, kaiming_init
from torch import nn
from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
xywhr2xyxyr)
from mmdet3d.models.utils import clip_sigmoid
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu
from mmdet.core import build_bbox_coder, multi_apply
from ... import builder
from ...builder import HEADS, build_loss
@HEADS.register_module()
class SeparateHead(nn.Module):
"""SeparateHead for CenterHead.
Args:
in_channels (int): Input channels for conv_layer.
heads (dict): Conv information.
head_conv (int): Output channels.
Default: 64.
final_kernel (int): Kernel size for the last conv layer.
Default: 1.
init_bias (float): Initial bias. Default: -2.19.
conv_cfg (dict): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict): Config of norm layer.
Default: dict(type='BN2d').
bias (str): Type of bias. Default: 'auto'.
"""
def __init__(self,
in_channels,
heads,
head_conv=64,
final_kernel=1,
init_bias=-2.19,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
**kwargs):
super(SeparateHead, self).__init__()
self.heads = heads
self.init_bias = init_bias
for head in self.heads:
classes, num_conv = self.heads[head]
conv_layers = []
for i in range(num_conv - 1):
conv_layers.append(
ConvModule(
in_channels,
head_conv,
kernel_size=final_kernel,
stride=1,
padding=final_kernel // 2,
bias=bias,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg))
conv_layers.append(
build_conv_layer(
conv_cfg,
head_conv,
classes,
kernel_size=final_kernel,
stride=1,
padding=final_kernel // 2,
bias=True))
conv_layers = nn.Sequential(*conv_layers)
self.__setattr__(head, conv_layers)
def init_weights(self):
"""Initialize weights."""
for head in self.heads:
if head == 'heatmap':
self.__getattr__(head)[-1].bias.data.fill_(self.init_bias)
else:
for m in self.__getattr__(head).modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
def forward(self, x):
"""Forward function for SepHead.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
dict[str: torch.Tensor]: contains the following keys:
-reg (torch.Tensor): 2D regression value with the \
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the \
shape of [B, 1, H, W].
-dim (torch.Tensor): Size value with the shape \
of [B, 3, H, W].
-rot (torch.Tensor): Rotation value with the \
shape of [B, 2, H, W].
-vel (torch.Tensor): Velocity value with the \
shape of [B, 2, H, W].
-heatmap (torch.Tensor): Heatmap with the shape of \
[B, N, H, W].
"""
ret_dict = dict()
for head in self.heads:
ret_dict[head] = self.__getattr__(head)(x)
return ret_dict
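A short shape check for the head above (illustrative):

import torch
head = SeparateHead(
    in_channels=64,
    heads=dict(reg=(2, 2), height=(1, 2), heatmap=(2, 2)),
    head_conv=64,
    final_kernel=3)
head.init_weights()
out = head(torch.rand(1, 64, 32, 32))
# out['reg']: [1, 2, 32, 32]; out['height']: [1, 1, 32, 32];
# out['heatmap']: [1, 2, 32, 32]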
@HEADS.register_module()
class DCNSeperateHead(nn.Module):
r"""DCNSeperateHead for CenterHead.
.. code-block:: none
/-----> DCN for heatmap task -----> heatmap task.
feature
\-----> DCN for regression tasks -----> regression tasks
Args:
in_channels (int): Input channels for conv_layer.
heads (dict): Conv information.
dcn_config (dict): Config of dcn layer.
num_cls (int): Number of classes (output channels of the
heatmap head).
head_conv (int): Mid channels of the conv layers. Default: 64.
final_kernel (int): Kernel size for the last conv layer.
Default: 1.
init_bias (float): Initial bias. Default: -2.19.
conv_cfg (dict): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict): Config of norm layer.
Default: dict(type='BN2d').
bias (str): Type of bias. Default: 'auto'.
""" # noqa: W605
def __init__(self,
in_channels,
num_cls,
heads,
dcn_config,
head_conv=64,
final_kernel=1,
init_bias=-2.19,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
**kwargs):
super(DCNSeperateHead, self).__init__()
if 'heatmap' in heads:
heads.pop('heatmap')
# feature adaptation with dcn
# use separate features for classification / regression
self.feature_adapt_cls = build_conv_layer(dcn_config)
self.feature_adapt_reg = build_conv_layer(dcn_config)
# heatmap prediction head
cls_head = [
ConvModule(
in_channels,
head_conv,
kernel_size=3,
padding=1,
conv_cfg=conv_cfg,
bias=bias,
norm_cfg=norm_cfg),
build_conv_layer(
conv_cfg,
head_conv,
num_cls,
kernel_size=3,
stride=1,
padding=1,
bias=bias)
]
self.cls_head = nn.Sequential(*cls_head)
self.init_bias = init_bias
# other regression target
self.task_head = SeparateHead(
in_channels,
heads,
head_conv=head_conv,
final_kernel=final_kernel,
bias=bias)
def init_weights(self):
"""Initialize weights."""
self.cls_head[-1].bias.data.fill_(self.init_bias)
self.task_head.init_weights()
def forward(self, x):
"""Forward function for DCNSepHead.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
dict[str: torch.Tensor]: contains the following keys:
-reg (torch.Tensor): 2D regression value with the \
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the \
shape of [B, 1, H, W].
-dim (torch.Tensor): Size value with the shape \
of [B, 3, H, W].
-rot (torch.Tensor): Rotation value with the \
shape of [B, 2, H, W].
-vel (torch.Tensor): Velocity value with the \
shape of [B, 2, H, W].
-heatmap (torch.Tensor): Heatmap with the shape of \
[B, N, H, W].
"""
center_feat = self.feature_adapt_cls(x)
reg_feat = self.feature_adapt_reg(x)
cls_score = self.cls_head(center_feat)
ret = self.task_head(reg_feat)
ret['heatmap'] = cls_score
return ret
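Since DCNSeperateHead needs the compiled deformable-conv op, a config-level sketch is safer here than a forward pass; the values mirror the dcn head unittest added below (in_channels/heads/num_cls are normally injected by CenterHead):

dcn_head_cfg = dict(
    type='DCNSeperateHead',
    in_channels=64,
    num_cls=2,
    heads=dict(reg=(2, 2), height=(1, 2), rot=(2, 2)),
    dcn_config=dict(
        type='DCN',
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        padding=1,
        groups=4,
        bias=True),
    init_bias=-2.19,
    final_kernel=3)
# build through the registry on a CUDA machine:
# from mmdet3d.models.builder import build_head
# head = build_head(dcn_head_cfg).cuda()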
@HEADS.register_module()
class CenterHead(nn.Module):
"""CenterHead for CenterPoint.
Args:
in_channels (list[int] | int): Channels of the input feature map.
Default: [128].
tasks (list[dict]): Task information including class number
and class names. Default: None.
train_cfg (dict): Config of training. Default: None.
test_cfg (dict): Config of testing. Default: None.
bbox_coder (dict): Config of bbox coder. Default: None.
common_heads (dict): Conv information for common heads.
Default: dict().
loss_cls (dict): Config of classification loss function.
Default: dict(type='GaussianFocalLoss', reduction='mean').
loss_bbox (dict): Config of regression loss function.
Default: dict(type='L1Loss', reduction='none', loss_weight=0.25).
seperate_head (dict): Config of the separate head. Default: dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3)
share_conv_channel (int): Output channels for share_conv_layer.
Default: 64.
num_heatmap_convs (int): Number of conv layers for heatmap conv layer.
Default: 2.
conv_cfg (dict): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict): Config of norm layer.
Default: dict(type='BN2d').
bias (str): Type of bias. Default: 'auto'.
norm_bbox (bool): Whether to log-normalize box dimensions.
Default: True.
"""
def __init__(self,
in_channels=[128],
tasks=None,
train_cfg=None,
test_cfg=None,
bbox_coder=None,
common_heads=dict(),
loss_cls=dict(type='GaussianFocalLoss', reduction='mean'),
loss_bbox=dict(
type='L1Loss', reduction='none', loss_weight=0.25),
seperate_head=dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3),
share_conv_channel=64,
num_heatmap_convs=2,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
norm_bbox=True):
super(CenterHead, self).__init__()
num_classes = [len(t['class_names']) for t in tasks]
self.class_names = [t['class_names'] for t in tasks]
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.in_channels = in_channels
self.num_classes = num_classes
self.norm_bbox = norm_bbox
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_anchor_per_locs = [n for n in num_classes]
# a shared convolution
self.shared_conv = ConvModule(
in_channels,
share_conv_channel,
kernel_size=3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=bias)
self.task_heads = nn.ModuleList()
for num_cls in num_classes:
heads = copy.deepcopy(common_heads)
heads.update(dict(heatmap=(num_cls, num_heatmap_convs)))
seperate_head.update(
in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
self.task_heads.append(builder.build_head(seperate_head))
def init_weights(self):
"""Initialize weights."""
for task_head in self.task_heads:
task_head.init_weights()
def forward_single(self, x):
"""Forward function for CenterPoint.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
list[dict]: Output results for tasks.
"""
ret_dicts = []
x = self.shared_conv(x)
for task in self.task_heads:
ret_dicts.append(task(x))
return ret_dicts
def forward(self, feats):
"""Forward pass.
Args:
feats (list[torch.Tensor]): Multi-level features, e.g.,
features produced by FPN.
Returns:
tuple(list[dict]): Output results for tasks.
"""
return multi_apply(self.forward_single, feats)
def _gather_feat(self, feat, ind, mask=None):
"""Gather feature map.
Given feature map and index, return indexed feature map.
Args:
feat (torch.tensor): Feature map with the shape of [B, H*W, 10].
ind (torch.Tensor): Index of the ground truth boxes with the
shape of [B, max_obj].
mask (torch.Tensor): Mask of the feature map with the shape
of [B, max_obj]. Default: None.
Returns:
torch.Tensor: Feature map after gathering with the shape
of [B, max_obj, 10].
"""
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
def get_targets(self, gt_bboxes_3d, gt_labels_3d):
"""Generate targets.
Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
truth boxes.
gt_labels_3d (list[torch.Tensor]): Labels of boxes.
Returns:
tuple[list[torch.Tensor]]: Tuple of target including \
the following results in order.
- list[torch.Tensor]: Heatmap scores.
- list[torch.Tensor]: Ground truth boxes.
- list[torch.Tensor]: Indexes indicating the \
position of the valid boxes.
- list[torch.Tensor]: Masks indicating which \
boxes are valid.
"""
heatmaps, anno_boxes, inds, masks = multi_apply(
self.get_targets_single, gt_bboxes_3d, gt_labels_3d)
# transpose heatmaps: the tensor dimensions differ across tasks,
# so use numpy instead of torch to do the transpose.
heatmaps = np.array(heatmaps).transpose(1, 0).tolist()
heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
# transpose anno_boxes
anno_boxes = np.array(anno_boxes).transpose(1, 0).tolist()
anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes]
# transpose inds
inds = np.array(inds).transpose(1, 0).tolist()
inds = [torch.stack(inds_) for inds_ in inds]
# transpose masks
masks = np.array(masks).transpose(1, 0).tolist()
masks = [torch.stack(masks_) for masks_ in masks]
return heatmaps, anno_boxes, inds, masks
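The numpy round-trips above turn a per-sample list of per-task targets into a per-task list of per-sample targets; a tiny illustration of the reshuffle:

import numpy as np
per_sample = [['hm_s0_t0', 'hm_s0_t1'], ['hm_s1_t0', 'hm_s1_t1']]
per_task = np.array(per_sample).transpose(1, 0).tolist()
# per_task == [['hm_s0_t0', 'hm_s1_t0'], ['hm_s0_t1', 'hm_s1_t1']]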
def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
"""Generate training targets for a single sample.
Args:
gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth boxes.
gt_labels_3d (torch.Tensor): Labels of boxes.
Returns:
tuple[list[torch.Tensor]]: Tuple of target including \
the following results in order.
- list[torch.Tensor]: Heatmap scores.
- list[torch.Tensor]: Ground truth boxes.
- list[torch.Tensor]: Indexes indicating the position \
of the valid boxes.
- list[torch.Tensor]: Masks indicating which boxes \
are valid.
"""
device = gt_labels_3d.device
gt_bboxes_3d = torch.cat(
(gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
dim=1).to(device)
max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg']
grid_size = torch.tensor(self.train_cfg['grid_size'])
pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
voxel_size = torch.tensor(self.train_cfg['voxel_size'])
feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor']
# reorganize the gt_dict by tasks
task_masks = []
flag = 0
for class_name in self.class_names:
task_masks.append([
torch.where(gt_labels_3d == class_name.index(i) + flag)
for i in class_name
])
flag += len(class_name)
task_boxes = []
task_classes = []
flag2 = 0
for idx, mask in enumerate(task_masks):
task_box = []
task_class = []
for m in mask:
task_box.append(gt_bboxes_3d[m])
# 0 is background for each task, so we need to add 1 here.
task_class.append(gt_labels_3d[m] + 1 - flag2)
task_boxes.append(torch.cat(task_box, axis=0).to(device))
task_classes.append(torch.cat(task_class).long().to(device))
flag2 += len(mask)
draw_gaussian = draw_heatmap_gaussian
heatmaps, anno_boxes, inds, masks = [], [], [], []
for idx, task_head in enumerate(self.task_heads):
heatmap = gt_bboxes_3d.new_zeros(
(len(self.class_names[idx]), feature_map_size[1],
feature_map_size[0]))
anno_box = gt_bboxes_3d.new_zeros((max_objs, 10),
dtype=torch.float32)
ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8)
num_objs = min(task_boxes[idx].shape[0], max_objs)
for k in range(num_objs):
cls_id = task_classes[idx][k] - 1
width = task_boxes[idx][k][3]
length = task_boxes[idx][k][4]
width = width / voxel_size[0] / self.train_cfg[
'out_size_factor']
length = length / voxel_size[1] / self.train_cfg[
'out_size_factor']
if width > 0 and length > 0:
radius = gaussian_radius(
(length, width),
min_overlap=self.train_cfg['gaussian_overlap'])
radius = max(self.train_cfg['min_radius'], int(radius))
# be really careful for the coordinate system of
# your box annotation.
x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][
1], task_boxes[idx][k][2]
coor_x = (
x - pc_range[0]
) / voxel_size[0] / self.train_cfg['out_size_factor']
coor_y = (
y - pc_range[1]
) / voxel_size[1] / self.train_cfg['out_size_factor']
center = torch.tensor([coor_x, coor_y],
dtype=torch.float32,
device=device)
center_int = center.to(torch.int32)
# throw out not in range objects to avoid out of array
# area when creating the heatmap
if not (0 <= center_int[0] < feature_map_size[0]
and 0 <= center_int[1] < feature_map_size[1]):
continue
draw_gaussian(heatmap[cls_id], center_int, radius)
new_idx = k
x, y = center_int[0], center_int[1]
assert (y * feature_map_size[0] + x <
feature_map_size[0] * feature_map_size[1])
ind[new_idx] = y * feature_map_size[0] + x
mask[new_idx] = 1
# TODO: support other outdoor dataset
vx, vy = task_boxes[idx][k][7:]
rot = task_boxes[idx][k][6]
box_dim = task_boxes[idx][k][3:6]
if self.norm_bbox:
box_dim = box_dim.log()
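# anno_box layout (10 values): [dx, dy] center offset within the
# cell, z of the gravity center, box dims (log-scaled when
# norm_bbox), sin(rot), cos(rot), vx, vy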
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0),
vx.unsqueeze(0),
vy.unsqueeze(0)
])
heatmaps.append(heatmap)
anno_boxes.append(anno_box)
masks.append(mask)
inds.append(ind)
return heatmaps, anno_boxes, inds, masks
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
"""Loss function for CenterHead.
Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground truth
boxes.
gt_labels_3d (list[torch.Tensor]): Labels of boxes.
preds_dicts (dict): Output of forward function.
Returns:
dict[str:torch.Tensor]: Loss of heatmap and bbox of each task.
"""
heatmaps, anno_boxes, inds, masks = self.get_targets(
gt_bboxes_3d, gt_labels_3d)
loss_dict = dict()
for task_id, preds_dict in enumerate(preds_dicts):
# heatmap focal loss
preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap'])
num_pos = heatmaps[task_id].eq(1).float().sum().item()
loss_heatmap = self.loss_cls(
preds_dict[0]['heatmap'],
heatmaps[task_id],
avg_factor=max(num_pos, 1))
target_box = anno_boxes[task_id]
# reconstruct the anno_box from multiple reg heads
preds_dict[0]['anno_box'] = torch.cat(
(preds_dict[0]['reg'], preds_dict[0]['height'],
preds_dict[0]['dim'], preds_dict[0]['rot'],
preds_dict[0]['vel']),
dim=1)
# Regression loss for dimension, offset, height, rotation
ind = inds[task_id]
num = masks[task_id].float().sum()
pred = preds_dict[0]['anno_box'].permute(0, 2, 3, 1).contiguous()
pred = pred.view(pred.size(0), -1, pred.size(3))
pred = self._gather_feat(pred, ind)
mask = masks[task_id].unsqueeze(2).expand_as(target_box).float()
isnotnan = (~torch.isnan(target_box)).float()
mask *= isnotnan
code_weights = self.train_cfg.get('code_weights', None)
bbox_weights = mask * mask.new_tensor(code_weights)
loss_bbox = self.loss_bbox(
pred, target_box, bbox_weights, avg_factor=(num + 1e-4))
loss_dict[f'task{task_id}.loss_heatmap'] = loss_heatmap
loss_dict[f'task{task_id}.loss_bbox'] = loss_bbox
return loss_dict
def get_bboxes(self, preds_dicts, img_metas, img=None, rescale=False):
"""Generate bboxes from bbox head predictions.
Args:
preds_dicts (tuple[list[dict]]): Prediction results.
img_metas (list[dict]): Point cloud and image's meta info.
Returns:
list[dict]: Decoded bbox, scores and labels after nms.
"""
rets = []
for task_id, preds_dict in enumerate(preds_dicts):
num_class_with_bg = self.num_classes[task_id]
batch_size = preds_dict[0]['heatmap'].shape[0]
batch_heatmap = preds_dict[0]['heatmap'].sigmoid()
batch_reg = preds_dict[0]['reg']
batch_hei = preds_dict[0]['height']
if self.norm_bbox:
batch_dim = torch.exp(preds_dict[0]['dim'])
else:
batch_dim = preds_dict[0]['dim']
batch_rots = preds_dict[0]['rot'][:, 0].unsqueeze(1)
batch_rotc = preds_dict[0]['rot'][:, 1].unsqueeze(1)
if 'vel' in preds_dict[0]:
batch_vel = preds_dict[0]['vel']
else:
batch_vel = None
temp = self.bbox_coder.decode(
batch_heatmap,
batch_rots,
batch_rotc,
batch_hei,
batch_dim,
batch_vel,
reg=batch_reg,
task_id=task_id)
assert self.test_cfg['nms_type'] in ['circle', 'rotate']
batch_reg_preds = [box['bboxes'] for box in temp]
batch_cls_preds = [box['scores'] for box in temp]
batch_cls_labels = [box['labels'] for box in temp]
if self.test_cfg['nms_type'] == 'circle':
ret_task = []
for i in range(batch_size):
boxes3d = temp[i]['bboxes']
scores = temp[i]['scores']
labels = temp[i]['labels']
centers = boxes3d[:, [0, 1]]
boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
keep = torch.tensor(
circle_nms(
boxes.detach().cpu().numpy(),
self.test_cfg['min_radius'][task_id],
post_max_size=self.test_cfg['post_max_size']),
dtype=torch.long,
device=boxes.device)
boxes3d = boxes3d[keep]
scores = scores[keep]
labels = labels[keep]
ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
ret_task.append(ret)
rets.append(ret_task)
else:
rets.append(
self.get_task_detections(num_class_with_bg,
batch_cls_preds, batch_reg_preds,
batch_cls_labels, img_metas))
# Merge branches results
num_samples = len(rets[0])
ret_list = []
for i in range(num_samples):
for k in rets[0][i].keys():
if k == 'bboxes':
bboxes = torch.cat([ret[i][k] for ret in rets])
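# decoded boxes carry gravity centers; shift z down by half the
# box height to match the bottom-center origin of box_type_3d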
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
bboxes = img_metas[i]['box_type_3d'](
bboxes, self.bbox_coder.code_size)
elif k == 'scores':
scores = torch.cat([ret[i][k] for ret in rets])
elif k == 'labels':
flag = 0
for j, num_class in enumerate(self.num_classes):
rets[j][i][k] += flag
flag += num_class
labels = torch.cat([ret[i][k] for ret in rets])
ret_list.append([bboxes, scores, labels])
return ret_list
def get_task_detections(self, num_class_with_bg, batch_cls_preds,
batch_reg_preds, batch_cls_labels, img_metas):
"""Rotate nms for each task.
Args:
num_class_with_bg (int): Number of classes for the current task.
batch_cls_preds (list[torch.Tensor]): Prediction score with the
shape of [N].
batch_reg_preds (list[torch.Tensor]): Prediction bbox with the
shape of [N, 9].
batch_cls_labels (list[torch.Tensor]): Prediction label with the
shape of [N].
img_metas (list[dict]): Meta information of each sample.
Returns:
list[dict[str: torch.Tensor]]: contains the following keys:
-bboxes (torch.Tensor): Prediction bboxes after nms with the \
shape of [N, 9].
-scores (torch.Tensor): Prediction scores after nms with the \
shape of [N].
-labels (torch.Tensor): Prediction labels after nms with the \
shape of [N].
"""
predictions_dicts = []
post_center_range = self.test_cfg['post_center_limit_range']
if len(post_center_range) > 0:
post_center_range = torch.tensor(
post_center_range,
dtype=batch_reg_preds[0].dtype,
device=batch_reg_preds[0].device)
for i, (box_preds, cls_preds, cls_labels) in enumerate(
zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)):
# Apply NMS in birdeye view
# get highest score per prediction, then apply nms
# to remove overlapped box.
if num_class_with_bg == 1:
top_scores = cls_preds.squeeze(-1)
top_labels = torch.zeros(
cls_preds.shape[0],
device=cls_preds.device,
dtype=torch.long)
else:
top_labels = cls_labels.long()
top_scores = cls_preds.squeeze(-1)
if self.test_cfg['score_threshold'] > 0.0:
thresh = torch.tensor(
[self.test_cfg['score_threshold']],
device=cls_preds.device).type_as(cls_preds)
top_scores_keep = top_scores >= thresh
top_scores = top_scores.masked_select(top_scores_keep)
if top_scores.shape[0] != 0:
if self.test_cfg['score_threshold'] > 0.0:
box_preds = box_preds[top_scores_keep]
top_labels = top_labels[top_scores_keep]
boxes_for_nms = xywhr2xyxyr(img_metas[i]['box_type_3d'](
box_preds[:, :], self.bbox_coder.code_size).bev)
# the nms in 3d detection just removes overlapped boxes.
selected = nms_gpu(
boxes_for_nms,
top_scores,
thresh=self.test_cfg['nms_iou_threshold'],
pre_maxsize=self.test_cfg['nms_pre_max_size'],
post_max_size=self.test_cfg['nms_post_max_size'])
else:
selected = []
selected_boxes = box_preds[selected]
selected_labels = top_labels[selected]
selected_scores = top_scores[selected]
# finally generate predictions.
if selected_boxes.shape[0] != 0:
box_preds = selected_boxes
scores = selected_scores
label_preds = selected_labels
final_box_preds = box_preds
final_scores = scores
final_labels = label_preds
if post_center_range is not None:
mask = (final_box_preds[:, :3] >=
post_center_range[:3]).all(1)
mask &= (final_box_preds[:, :3] <=
post_center_range[3:]).all(1)
predictions_dict = dict(
bboxes=final_box_preds[mask],
scores=final_scores[mask],
labels=final_labels[mask])
else:
predictions_dict = dict(
bboxes=final_box_preds,
scores=final_scores,
labels=final_labels)
else:
dtype = batch_reg_preds[0].dtype
device = batch_reg_preds[0].device
predictions_dict = dict(
bboxes=torch.zeros([0, self.bbox_coder.code_size],
dtype=dtype,
device=device),
scores=torch.zeros([0], dtype=dtype, device=device),
labels=torch.zeros([0],
dtype=top_labels.dtype,
device=device))
predictions_dicts.append(predictions_dict)
return predictions_dicts
from .clip_sigmoid import clip_sigmoid
__all__ = ['clip_sigmoid']
import torch
def clip_sigmoid(x, eps=1e-4):
"""Sigmoid function for input feature.
Args:
x (torch.Tensor): Input feature map with the shape of [B, N, H, W].
eps (float): Lower bound of the range to be clamped to. Defaults
to 1e-4.
Returns:
torch.Tensor: Feature map after sigmoid.
"""
y = torch.clamp(x.sigmoid_(), min=eps, max=1 - eps)
return y
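A small illustration of why the clamp matters (illustrative): the Gaussian focal loss takes logs of these probabilities, so exact 0 or 1 would overflow.

import torch
probs = clip_sigmoid(torch.tensor([-20.0, 0.0, 20.0]))
# ~tensor([1.0000e-04, 5.0000e-01, 9.9990e-01]); never exactly 0 or 1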
...@@ -22,24 +22,32 @@ def boxes_iou_bev(boxes_a, boxes_b):
return ans_iou
def nms_gpu(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):
"""Nms function with gpu implementation.
Args:
boxes (torch.Tensor): Input boxes with the shape of [N, 5]
([x1, y1, x2, y2, ry]).
scores (torch.Tensor): Scores of boxes with the shape of [N].
thresh (float): IoU threshold for suppression.
pre_maxsize (int): Max size of boxes before nms. Default: None.
post_max_size (int): Max size of boxes after nms. Default: None.
Returns:
torch.Tensor: Indexes after nms.
"""
order = scores.sort(0, descending=True)[1]
if pre_maxsize is not None:
order = order[:pre_maxsize]
boxes = boxes[order].contiguous()
keep = torch.zeros(boxes.size(0), dtype=torch.long)
num_out = iou3d_cuda.nms_gpu(boxes, keep, thresh, boxes.device.index)
keep = order[keep[:num_out].cuda(boxes.device)].contiguous()
if post_max_size is not None:
keep = keep[:post_max_size]
return keep
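A hedged usage sketch of the extended signature (requires CUDA and the compiled iou3d op; box values are illustrative):

import torch
xy = torch.rand(100, 2) * 50
wh = torch.rand(100, 2) * 4 + 0.5
ry = torch.rand(100, 1)
boxes = torch.cat([xy, xy + wh, ry], dim=1).cuda()  # [x1, y1, x2, y2, ry]
scores = torch.rand(100).cuda()
keep = nms_gpu(boxes, scores, thresh=0.2, pre_maxsize=1000, post_max_size=83)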
def nms_normal_gpu(boxes, scores, thresh):
...
...@@ -323,3 +323,31 @@ def test_anchor_free_box_coder():
assert dir_res_norm.shape == torch.Size([2, 256, 12])
assert dir_res.shape == torch.Size([2, 256, 12])
assert size.shape == torch.Size([2, 256, 3])
def test_centerpoint_bbox_coder():
bbox_coder_cfg = dict(
type='CenterPointBBoxCoder',
post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
max_num=500,
score_threshold=0.1,
pc_range=[-51.2, -51.2],
out_size_factor=4,
voxel_size=[0.2, 0.2])
bbox_coder = build_bbox_coder(bbox_coder_cfg)
batch_dim = torch.rand([2, 3, 128, 128])
batch_hei = torch.rand([2, 1, 128, 128])
batch_hm = torch.rand([2, 2, 128, 128])
batch_reg = torch.rand([2, 2, 128, 128])
batch_rotc = torch.rand([2, 1, 128, 128])
batch_rots = torch.rand([2, 1, 128, 128])
batch_vel = torch.rand([2, 2, 128, 128])
temp = bbox_coder.decode(batch_hm, batch_rots, batch_rotc, batch_hei,
batch_dim, batch_vel, batch_reg, 5)
for i in range(len(temp)):
assert temp[i]['bboxes'].shape == torch.Size([500, 9])
assert temp[i]['scores'].shape == torch.Size([500])
assert temp[i]['labels'].shape == torch.Size([500])
...@@ -8,6 +8,7 @@ from os.path import dirname, exists, join
from mmdet3d.core.bbox import (Box3DMode, DepthInstance3DBoxes,
LiDARInstance3DBoxes)
from mmdet3d.models.builder import build_head
from mmdet.apis import set_random_seed
def _setup_seed(seed):
...@@ -689,6 +690,199 @@ def test_h3d_head():
assert ret_dict['primitive_sem_matching_loss'] >= 0
def test_center_head():
tasks = [
dict(num_class=1, class_names=['car']),
dict(num_class=2, class_names=['truck', 'construction_vehicle']),
dict(num_class=2, class_names=['bus', 'trailer']),
dict(num_class=1, class_names=['barrier']),
dict(num_class=2, class_names=['motorcycle', 'bicycle']),
dict(num_class=2, class_names=['pedestrian', 'traffic_cone']),
]
bbox_cfg = dict(
type='CenterPointBBoxCoder',
post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
max_num=500,
score_threshold=0.1,
pc_range=[-51.2, -51.2],
out_size_factor=8,
voxel_size=[0.2, 0.2])
train_cfg = dict(
grid_size=[1024, 1024, 40],
point_cloud_range=[-51.2, -51.2, -5., 51.2, 51.2, 3.],
voxel_size=[0.1, 0.1, 0.2],
out_size_factor=8,
dense_reg=1,
gaussian_overlap=0.1,
max_objs=500,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 1.0, 1.0],
min_radius=2)
test_cfg = dict(
post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
max_per_img=500,
max_pool_nms=False,
min_radius=[4, 12, 10, 1, 0.85, 0.175],
post_max_size=83,
score_threshold=0.1,
pc_range=[-51.2, -51.2],
out_size_factor=8,
voxel_size=[0.2, 0.2],
nms_type='circle')
center_head_cfg = dict(
type='CenterHead',
in_channels=sum([256, 256]),
tasks=tasks,
train_cfg=train_cfg,
test_cfg=test_cfg,
bbox_coder=bbox_cfg,
common_heads=dict(
reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
share_conv_channel=64,
norm_bbox=True)
center_head = build_head(center_head_cfg)
x = torch.rand([2, 512, 128, 128])
output = center_head([x])
for i in range(6):
assert output[i][0]['reg'].shape == torch.Size([2, 2, 128, 128])
assert output[i][0]['height'].shape == torch.Size([2, 1, 128, 128])
assert output[i][0]['dim'].shape == torch.Size([2, 3, 128, 128])
assert output[i][0]['rot'].shape == torch.Size([2, 2, 128, 128])
assert output[i][0]['vel'].shape == torch.Size([2, 2, 128, 128])
assert output[i][0]['heatmap'].shape == torch.Size(
[2, tasks[i]['num_class'], 128, 128])
# test get_bboxes
img_metas = [
dict(box_type_3d=LiDARInstance3DBoxes),
dict(box_type_3d=LiDARInstance3DBoxes)
]
ret_lists = center_head.get_bboxes(output, img_metas)
for ret_list in ret_lists:
assert ret_list[0].tensor.shape[0] <= 500
assert ret_list[1].shape[0] <= 500
assert ret_list[2].shape[0] <= 500
def test_dcn_center_head():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and CUDA')
set_random_seed(0)
tasks = [
dict(num_class=1, class_names=['car']),
dict(num_class=2, class_names=['truck', 'construction_vehicle']),
dict(num_class=2, class_names=['bus', 'trailer']),
dict(num_class=1, class_names=['barrier']),
dict(num_class=2, class_names=['motorcycle', 'bicycle']),
dict(num_class=2, class_names=['pedestrian', 'traffic_cone']),
]
voxel_size = [0.2, 0.2, 8]
dcn_center_head_cfg = dict(
type='CenterHead',
in_channels=sum([128, 128, 128]),
tasks=[
dict(num_class=1, class_names=['car']),
dict(num_class=2, class_names=['truck', 'construction_vehicle']),
dict(num_class=2, class_names=['bus', 'trailer']),
dict(num_class=1, class_names=['barrier']),
dict(num_class=2, class_names=['motorcycle', 'bicycle']),
dict(num_class=2, class_names=['pedestrian', 'traffic_cone']),
],
common_heads={
'reg': (2, 2),
'height': (1, 2),
'dim': (3, 2),
'rot': (2, 2),
'vel': (2, 2)
},
share_conv_channel=64,
bbox_coder=dict(
type='CenterPointBBoxCoder',
post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
max_num=500,
score_threshold=0.1,
pc_range=[-51.2, -51.2],
out_size_factor=4,
voxel_size=voxel_size[:2],
code_size=9),
seperate_head=dict(
type='DCNSeperateHead',
dcn_config=dict(
type='DCN',
in_channels=64,
out_channels=64,
kernel_size=3,
padding=1,
groups=4,
bias=True),
init_bias=-2.19,
final_kernel=3),
loss_cls=dict(type='GaussianFocalLoss', reduction='mean'),
loss_bbox=dict(type='L1Loss', reduction='none', loss_weight=0.25),
norm_bbox=True)
# model training and testing settings
train_cfg = dict(
grid_size=[512, 512, 1],
point_cloud_range=[-51.2, -51.2, -5., 51.2, 51.2, 3.],
voxel_size=voxel_size,
out_size_factor=4,
dense_reg=1,
gaussian_overlap=0.1,
max_objs=500,
min_radius=2,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 1.0, 1.0])
test_cfg = dict(
post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
max_per_img=500,
max_pool_nms=False,
min_radius=[4, 12, 10, 1, 0.85, 0.175],
post_max_size=83,
score_threshold=0.1,
pc_range=[-51.2, -51.2],
out_size_factor=4,
voxel_size=voxel_size[:2],
nms_type='circle')
dcn_center_head_cfg.update(train_cfg=train_cfg, test_cfg=test_cfg)
dcn_center_head = build_head(dcn_center_head_cfg).cuda()
x = torch.ones([2, 384, 128, 128]).cuda()
output = dcn_center_head([x])
for i in range(6):
assert output[i][0]['reg'].shape == torch.Size([2, 2, 128, 128])
assert output[i][0]['height'].shape == torch.Size([2, 1, 128, 128])
assert output[i][0]['dim'].shape == torch.Size([2, 3, 128, 128])
assert output[i][0]['rot'].shape == torch.Size([2, 2, 128, 128])
assert output[i][0]['vel'].shape == torch.Size([2, 2, 128, 128])
assert output[i][0]['heatmap'].shape == torch.Size(
[2, tasks[i]['num_class'], 128, 128])
# Test loss.
gt_bboxes_0 = LiDARInstance3DBoxes(torch.rand([10, 9]).cuda(), box_dim=9)
gt_bboxes_1 = LiDARInstance3DBoxes(torch.rand([20, 9]).cuda(), box_dim=9)
gt_labels_0 = torch.randint(1, 11, [10]).cuda()
gt_labels_1 = torch.randint(1, 11, [20]).cuda()
gt_bboxes_3d = [gt_bboxes_0, gt_bboxes_1]
gt_labels_3d = [gt_labels_0, gt_labels_1]
loss = dcn_center_head.loss(gt_bboxes_3d, gt_labels_3d, output)
loss_sum = torch.sum(torch.stack([item for _, item in loss.items()]))
assert torch.isclose(loss_sum, torch.tensor(21972.1230))
# test get_bboxes
img_metas = [
dict(box_type_3d=LiDARInstance3DBoxes),
dict(box_type_3d=LiDARInstance3DBoxes)
]
ret_lists = dcn_center_head.get_bboxes(output, img_metas)
for ret_list in ret_lists:
assert ret_list[0].tensor.shape[0] <= 500
assert ret_list[1].shape[0] <= 500
assert ret_list[2].shape[0] <= 500
def test_ssd3d_head():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
...
import numpy as np
import torch
...@@ -55,3 +56,19 @@ def test_aligned_3d_nms():
])
assert torch.all(pick == expected_pick)
def test_circle_nms():
from mmdet3d.core.post_processing import circle_nms
boxes = torch.tensor([[-11.1100, 2.1300, 0.8823],
[-11.2810, 2.2422, 0.8914],
[-10.3966, -0.3198, 0.8643],
[-10.2906, -13.3159, 0.8401],
[5.6518, 9.9791, 0.8271],
[-11.2652, 13.3637, 0.8267],
[4.7768, -13.0409, 0.7810],
[5.6621, 9.0422, 0.7753],
[-10.5561, 18.9627, 0.7518],
[-10.5643, 13.2293, 0.7200]])
keep = circle_nms(boxes.numpy(), 0.175)
expected_keep = [1, 2, 3, 4, 5, 6, 7, 8, 9]
assert np.all(keep == expected_keep)
import torch
from mmdet3d.core import draw_heatmap_gaussian
def test_gaussian():
heatmap = torch.zeros((128, 128))
ct_int = torch.tensor([64, 64], dtype=torch.int32)
radius = 2
draw_heatmap_gaussian(heatmap, ct_int, radius)
assert torch.isclose(torch.sum(heatmap), torch.tensor(4.3505), atol=1e-3)