Unverified Commit b016d90c authored by twang's avatar twang Committed by GitHub
Browse files

[Feature] Add shape-aware grouping head in SSN (#147)

* Add shape-aware grouping head

* Reformat docstrings

* Remove rewritten get_anchors in shape_aware_head

* Refactor and simplify shape-aware grouping head

* Fix docstring

* Remove fixed preset shape heads

* Add unittest for AlignedAnchor3DRangeGeneratorPerCls

* Add unittest for get bboxes in shape_aware_head

* Add unittest for loss of shape_aware_head

* Fix non-standard docstrings

* Minor fix for a comment

* Add assertion to make sure not all boxes are filtered
parent cec2d8b0
from mmdet.core.anchor import build_anchor_generator from mmdet.core.anchor import build_anchor_generator
from .anchor_3d_generator import (AlignedAnchor3DRangeGenerator, from .anchor_3d_generator import (AlignedAnchor3DRangeGenerator,
AlignedAnchor3DRangeGeneratorPerCls,
Anchor3DRangeGenerator) Anchor3DRangeGenerator)
__all__ = [ __all__ = [
'AlignedAnchor3DRangeGenerator', 'Anchor3DRangeGenerator', 'AlignedAnchor3DRangeGenerator', 'Anchor3DRangeGenerator',
'build_anchor_generator' 'build_anchor_generator', 'AlignedAnchor3DRangeGeneratorPerCls'
] ]
...@@ -323,3 +323,81 @@ class AlignedAnchor3DRangeGenerator(Anchor3DRangeGenerator): ...@@ -323,3 +323,81 @@ class AlignedAnchor3DRangeGenerator(Anchor3DRangeGenerator):
# custom[:] = self.custom_values # custom[:] = self.custom_values
ret = torch.cat([ret, custom], dim=-1) ret = torch.cat([ret, custom], dim=-1)
return ret return ret
@ANCHOR_GENERATORS.register_module()
class AlignedAnchor3DRangeGeneratorPerCls(AlignedAnchor3DRangeGenerator):
    """Per-class 3D anchor generator by range.

    Anchors are generated from one (range, size) pair per class, so the
    feature maps of different classes may have different sizes.

    Args:
        kwargs (dict): Arguments are the same as those in
            :class:`AlignedAnchor3DRangeGenerator`.
    """

    def __init__(self, **kwargs):
        super(AlignedAnchor3DRangeGeneratorPerCls, self).__init__(**kwargs)
        # Only one scale (i.e. one feature level) is handled by this
        # generator; grid_anchors below relies on this invariant.
        assert len(self.scales) == 1, \
            'Multi-scale feature map levels are not supported currently ' \
            'in this kind of anchor generator.'

    def grid_anchors(self, featmap_sizes, device='cuda'):
        """Generate grid anchors in multiple feature levels.

        Args:
            featmap_sizes (list[tuple]): List of feature map sizes for
                different classes in a single feature level.
            device (str): Device where the anchors will be put on.

        Returns:
            list[list[torch.Tensor]]: Anchors in multiple feature levels.
                Note that only a single feature level is currently
                supported. Each tensor has the shape
                [num_sizes/ranges*num_rots*featmap_size, box_code_size].
        """
        # A single level is guaranteed by the assertion in __init__, so the
        # returned list holds exactly one per-class anchor list.
        return [
            self.multi_cls_grid_anchors(
                featmap_sizes, self.scales[0], device=device)
        ]

    def multi_cls_grid_anchors(self, featmap_sizes, scale, device='cuda'):
        """Generate grid anchors of a single level feature map for
        multi-class with different feature map sizes.

        This function is usually called by method ``self.grid_anchors``.

        Args:
            featmap_sizes (list[tuple]): List of feature map sizes for
                different classes in a single feature level.
            scale (float): Scale factor of the anchors in the current level.
            device (str, optional): Device the tensor will be put on.
                Defaults to 'cuda'.

        Returns:
            torch.Tensor: Anchors in the overall feature map.
        """
        assert len(featmap_sizes) == len(self.sizes) == len(self.ranges), \
            'The number of different feature map sizes anchor sizes and ' \
            'ranges should be the same.'

        multi_cls_anchors = []
        for featmap_size, anchor_range, anchor_size in zip(
                featmap_sizes, self.ranges, self.sizes):
            cls_anchors = self.anchors_single_range(
                featmap_size,
                anchor_range,
                scale,
                anchor_size,
                self.rotations,
                device=device)
            ndim = len(featmap_size)
            # collapse the (num_sizes/ranges, num_rots) axes:
            # [*featmap_size, num_sizes/ranges*num_rots, box_code_size]
            cls_anchors = cls_anchors.view(*featmap_size, -1,
                                           cls_anchors.size(-1))
            # move the anchor axis in front of the spatial axes:
            # [num_sizes/ranges*num_rots, *featmap_size, box_code_size]
            cls_anchors = cls_anchors.permute(ndim, *range(ndim), ndim + 1)
            # flatten to [num_sizes/ranges*num_rots*featmap_size,
            # box_code_size]
            multi_cls_anchors.append(
                cls_anchors.reshape(-1, cls_anchors.size(-1)))
        return multi_cls_anchors
...@@ -3,10 +3,11 @@ from .base_conv_bbox_head import BaseConvBboxHead ...@@ -3,10 +3,11 @@ from .base_conv_bbox_head import BaseConvBboxHead
from .centerpoint_head import CenterHead from .centerpoint_head import CenterHead
from .free_anchor3d_head import FreeAnchor3DHead from .free_anchor3d_head import FreeAnchor3DHead
from .parta2_rpn_head import PartA2RPNHead from .parta2_rpn_head import PartA2RPNHead
from .shape_aware_head import ShapeAwareHead
from .ssd_3d_head import SSD3DHead from .ssd_3d_head import SSD3DHead
from .vote_head import VoteHead from .vote_head import VoteHead
__all__ = [ __all__ = [
'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead', 'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead',
'SSD3DHead', 'BaseConvBboxHead', 'CenterHead' 'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead'
] ]
This diff is collapsed.
...@@ -40,6 +40,16 @@ class AnchorTrainMixin(object): ...@@ -40,6 +40,16 @@ class AnchorTrainMixin(object):
num_imgs = len(input_metas) num_imgs = len(input_metas)
assert len(anchor_list) == num_imgs assert len(anchor_list) == num_imgs
if isinstance(anchor_list[0][0], list):
# sizes of anchors are different
# anchor number of a single level
num_level_anchors = [
sum([anchor.size(0) for anchor in anchors])
for anchors in anchor_list[0]
]
for i in range(num_imgs):
anchor_list[i] = anchor_list[i][0]
else:
# anchor number of multi levels # anchor number of multi levels
num_level_anchors = [ num_level_anchors = [
anchors.view(-1, self.box_code_size).size(0) anchors.view(-1, self.box_code_size).size(0)
...@@ -112,7 +122,8 @@ class AnchorTrainMixin(object): ...@@ -112,7 +122,8 @@ class AnchorTrainMixin(object):
Returns: Returns:
tuple[torch.Tensor]: Anchor targets. tuple[torch.Tensor]: Anchor targets.
""" """
if isinstance(self.bbox_assigner, list): if isinstance(self.bbox_assigner,
list) and (not isinstance(anchors, list)):
feat_size = anchors.size(0) * anchors.size(1) * anchors.size(2) feat_size = anchors.size(0) * anchors.size(1) * anchors.size(2)
rot_angles = anchors.size(-2) rot_angles = anchors.size(-2)
assert len(self.bbox_assigner) == anchors.size(-3) assert len(self.bbox_assigner) == anchors.size(-3)
...@@ -129,12 +140,11 @@ class AnchorTrainMixin(object): ...@@ -129,12 +140,11 @@ class AnchorTrainMixin(object):
anchor_targets = self.anchor_target_single_assigner( anchor_targets = self.anchor_target_single_assigner(
assigner, current_anchors, gt_bboxes[gt_per_cls, :], assigner, current_anchors, gt_bboxes[gt_per_cls, :],
gt_bboxes_ignore, gt_labels[gt_per_cls], input_meta, gt_bboxes_ignore, gt_labels[gt_per_cls], input_meta,
label_channels, num_classes, sampling) num_classes, sampling)
else: else:
anchor_targets = self.anchor_target_single_assigner( anchor_targets = self.anchor_target_single_assigner(
assigner, current_anchors, gt_bboxes, gt_bboxes_ignore, assigner, current_anchors, gt_bboxes, gt_bboxes_ignore,
gt_labels, input_meta, label_channels, num_classes, gt_labels, input_meta, num_classes, sampling)
sampling)
(labels, label_weights, bbox_targets, bbox_weights, (labels, label_weights, bbox_targets, bbox_weights,
dir_targets, dir_weights, pos_inds, neg_inds) = anchor_targets dir_targets, dir_weights, pos_inds, neg_inds) = anchor_targets
...@@ -170,10 +180,59 @@ class AnchorTrainMixin(object): ...@@ -170,10 +180,59 @@ class AnchorTrainMixin(object):
return (total_labels, total_label_weights, total_bbox_targets, return (total_labels, total_label_weights, total_bbox_targets,
total_bbox_weights, total_dir_targets, total_dir_weights, total_bbox_weights, total_dir_targets, total_dir_weights,
total_pos_inds, total_neg_inds) total_pos_inds, total_neg_inds)
elif isinstance(self.bbox_assigner, list) and isinstance(
anchors, list):
# class-aware anchors with different feature map sizes
assert len(self.bbox_assigner) == len(anchors), \
'The number of bbox assigners and anchors should be the same.'
(total_labels, total_label_weights, total_bbox_targets,
total_bbox_weights, total_dir_targets, total_dir_weights,
total_pos_inds, total_neg_inds) = [], [], [], [], [], [], [], []
current_anchor_num = 0
for i, assigner in enumerate(self.bbox_assigner):
current_anchors = anchors[i]
current_anchor_num += current_anchors.size(0)
if self.assign_per_class:
gt_per_cls = (gt_labels == i)
anchor_targets = self.anchor_target_single_assigner(
assigner, current_anchors, gt_bboxes[gt_per_cls, :],
gt_bboxes_ignore, gt_labels[gt_per_cls], input_meta,
num_classes, sampling)
else: else:
return self.anchor_target_single_assigner( anchor_targets = self.anchor_target_single_assigner(
self.bbox_assigner, anchors, gt_bboxes, gt_bboxes_ignore, assigner, current_anchors, gt_bboxes, gt_bboxes_ignore,
gt_labels, input_meta, label_channels, num_classes, sampling) gt_labels, input_meta, num_classes, sampling)
(labels, label_weights, bbox_targets, bbox_weights,
dir_targets, dir_weights, pos_inds, neg_inds) = anchor_targets
total_labels.append(labels)
total_label_weights.append(label_weights)
total_bbox_targets.append(
bbox_targets.reshape(-1, anchors[i].size(-1)))
total_bbox_weights.append(
bbox_weights.reshape(-1, anchors[i].size(-1)))
total_dir_targets.append(dir_targets)
total_dir_weights.append(dir_weights)
total_pos_inds.append(pos_inds)
total_neg_inds.append(neg_inds)
total_labels = torch.cat(total_labels, dim=0)
total_label_weights = torch.cat(total_label_weights, dim=0)
total_bbox_targets = torch.cat(total_bbox_targets, dim=0)
total_bbox_weights = torch.cat(total_bbox_weights, dim=0)
total_dir_targets = torch.cat(total_dir_targets, dim=0)
total_dir_weights = torch.cat(total_dir_weights, dim=0)
total_pos_inds = torch.cat(total_pos_inds, dim=0)
total_neg_inds = torch.cat(total_neg_inds, dim=0)
return (total_labels, total_label_weights, total_bbox_targets,
total_bbox_weights, total_dir_targets, total_dir_weights,
total_pos_inds, total_neg_inds)
else:
return self.anchor_target_single_assigner(self.bbox_assigner,
anchors, gt_bboxes,
gt_bboxes_ignore,
gt_labels, input_meta,
num_classes, sampling)
def anchor_target_single_assigner(self, def anchor_target_single_assigner(self,
bbox_assigner, bbox_assigner,
...@@ -182,7 +241,6 @@ class AnchorTrainMixin(object): ...@@ -182,7 +241,6 @@ class AnchorTrainMixin(object):
gt_bboxes_ignore, gt_bboxes_ignore,
gt_labels, gt_labels,
input_meta, input_meta,
label_channels=1,
num_classes=1, num_classes=1,
sampling=True): sampling=True):
"""Assign anchors and encode positive anchors. """Assign anchors and encode positive anchors.
...@@ -194,7 +252,6 @@ class AnchorTrainMixin(object): ...@@ -194,7 +252,6 @@ class AnchorTrainMixin(object):
gt_bboxes_ignore (torch.Tensor): Ignored gt bboxes. gt_bboxes_ignore (torch.Tensor): Ignored gt bboxes.
gt_labels (torch.Tensor): Gt class labels. gt_labels (torch.Tensor): Gt class labels.
input_meta (dict): Meta info of each image. input_meta (dict): Meta info of each image.
label_channels (int): The channel of labels.
num_classes (int): The number of classes. num_classes (int): The number of classes.
sampling (bool): Whether to sample anchors. sampling (bool): Whether to sample anchors.
......
...@@ -181,3 +181,58 @@ def test_aligned_anchor_generator(): ...@@ -181,3 +181,58 @@ def test_aligned_anchor_generator():
# set [:56:7] thus it could cover 8 (len(size) * len(rotations)) # set [:56:7] thus it could cover 8 (len(size) * len(rotations))
# anchors on 8 location # anchors on 8 location
assert single_level_anchor[:56:7].allclose(expected_grid_anchors[i]) assert single_level_anchor[:56:7].allclose(expected_grid_anchors[i])
def test_aligned_anchor_generator_per_cls():
    """Check per-class grid anchors generated with distinct feature map
    sizes for each class."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    anchor_generator_cfg = dict(
        type='AlignedAnchor3DRangeGeneratorPerCls',
        ranges=[[-100, -100, -1.80, 100, 100, -1.80],
                [-100, -100, -1.30, 100, 100, -1.30]],
        sizes=[[0.63, 1.76, 1.44], [0.96, 2.35, 1.59]],
        custom_values=[0, 0],
        rotations=[0, 1.57],
        reshape_out=False)
    anchor_generator = build_anchor_generator(anchor_generator_cfg)

    # one feature map per class, within a single feature level
    featmap_sizes = [(100, 100), (50, 50)]

    # check base anchors
    expected_grid_anchors = [[
        torch.tensor([[
            -99.0000, -99.0000, -1.8000, 0.6300, 1.7600, 1.4400, 0.0000,
            0.0000, 0.0000
        ],
                      [
                          -99.0000, -99.0000, -1.8000, 0.6300, 1.7600, 1.4400,
                          1.5700, 0.0000, 0.0000
                      ]],
                     device=device),
        torch.tensor([[
            -98.0000, -98.0000, -1.3000, 0.9600, 2.3500, 1.5900, 0.0000,
            0.0000, 0.0000
        ],
                      [
                          -98.0000, -98.0000, -1.3000, 0.9600, 2.3500, 1.5900,
                          1.5700, 0.0000, 0.0000
                      ]],
                     device=device)
    ]]
    expected_multi_level_shapes = [[
        torch.Size([20000, 9]), torch.Size([5000, 9])
    ]]

    multi_level_anchors = anchor_generator.grid_anchors(
        featmap_sizes, device=device)
    for lvl_anchors, lvl_shapes, lvl_expected in zip(
            multi_level_anchors, expected_multi_level_shapes,
            expected_grid_anchors):
        assert len(lvl_anchors) == len(lvl_shapes)
        # sample [:2*interval:interval] so that it covers
        # 2 (len(size) * len(rotations)) anchors at 2 locations
        # Note that len(size) for each class is always 1 in this case
        for cls_anchors, cls_shape, cls_expected in zip(
                lvl_anchors, lvl_shapes, lvl_expected):
            interval = int(cls_shape[0] / 2)
            assert cls_anchors[:2 * interval:interval].allclose(cls_expected)
...@@ -945,3 +945,102 @@ def test_ssd3d_head(): ...@@ -945,3 +945,102 @@ def test_ssd3d_head():
assert results[0][0].tensor.shape[1] == 7 assert results[0][0].tensor.shape[1] == 7
assert results[0][1].shape[0] >= 0 assert results[0][1].shape[0] >= 0
assert results[0][2].shape[0] >= 0 assert results[0][2].shape[0] >= 0
def test_shape_aware_head_loss():
    """Smoke-test forward and loss computation of ShapeAwareHead.

    Builds the head from the SSN Lyft config, runs a forward pass on random
    features, checks loss positivity with real ground truth, and verifies
    that localization losses vanish when ground truth is empty.
    """
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')
    bbox_head_cfg = _get_pts_bbox_head_cfg(
        'ssn/hv_ssn_secfpn_sbn-all_2x16_2x_lyft-3d.py')
    # modify bn config to avoid bugs caused by syncbn
    for task in bbox_head_cfg['tasks']:
        task['norm_cfg'] = dict(type='BN2d')

    from mmdet3d.models.builder import build_head
    # renamed from `self`: using `self` as a plain local variable shadows
    # the method-receiver convention and is confusing to readers
    head = build_head(bbox_head_cfg)
    head.cuda()
    assert len(head.heads) == 4
    assert isinstance(head.heads[0].conv_cls, torch.nn.modules.conv.Conv2d)
    assert head.heads[0].conv_cls.in_channels == 64
    assert head.heads[0].conv_cls.out_channels == 36
    assert head.heads[0].conv_reg.out_channels == 28
    assert head.heads[0].conv_dir_cls.out_channels == 8

    # test forward
    feats = list()
    feats.append(torch.rand([2, 384, 200, 200], dtype=torch.float32).cuda())
    (cls_score, bbox_pred, dir_cls_preds) = head.forward(feats)
    assert cls_score[0].shape == torch.Size([2, 420000, 9])
    assert bbox_pred[0].shape == torch.Size([2, 420000, 7])
    assert dir_cls_preds[0].shape == torch.Size([2, 420000, 2])

    # test loss
    gt_bboxes = [
        LiDARInstance3DBoxes(
            torch.tensor(
                [[-14.5695, -6.4169, -2.1054, 1.8830, 4.6720, 1.4840, 1.5587],
                 [25.7215, 3.4581, -1.3456, 1.6720, 4.4090, 1.5830, 1.5301]],
                dtype=torch.float32).cuda()),
        LiDARInstance3DBoxes(
            torch.tensor(
                [[-50.763, -3.5517, -0.99658, 1.7430, 4.4020, 1.6990, 1.7874],
                 [-68.720, 0.033, -0.75276, 1.7860, 4.9100, 1.6610, 1.7525]],
                dtype=torch.float32).cuda())
    ]
    gt_labels = list(torch.tensor([[4, 4], [4, 4]], dtype=torch.int64).cuda())
    input_metas = [{
        'sample_idx': 1234
    }, {
        'sample_idx': 2345
    }]  # fake input_metas
    losses = head.loss(cls_score, bbox_pred, dir_cls_preds, gt_bboxes,
                       gt_labels, input_metas)
    assert losses['loss_cls'][0] > 0
    assert losses['loss_bbox'][0] > 0
    assert losses['loss_dir'][0] > 0

    # test empty ground truth case: classification loss stays positive
    # (all anchors become background), but with nothing to regress the
    # box and direction losses must be exactly zero
    gt_bboxes = list(torch.empty((2, 0, 7)).cuda())
    gt_labels = list(torch.empty((2, 0)).cuda())
    empty_gt_losses = head.loss(cls_score, bbox_pred, dir_cls_preds,
                                gt_bboxes, gt_labels, input_metas)
    assert empty_gt_losses['loss_cls'][0] > 0
    assert empty_gt_losses['loss_bbox'][0] == 0
    assert empty_gt_losses['loss_dir'][0] == 0
def test_shape_aware_head_getboxes():
    """Smoke-test box decoding (``get_bboxes``) of ShapeAwareHead.

    Runs a forward pass on random features and checks that the decoded
    boxes survive score filtering with scores above the threshold.
    """
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')
    bbox_head_cfg = _get_pts_bbox_head_cfg(
        'ssn/hv_ssn_secfpn_sbn-all_2x16_2x_lyft-3d.py')
    # modify bn config to avoid bugs caused by syncbn
    for task in bbox_head_cfg['tasks']:
        task['norm_cfg'] = dict(type='BN2d')

    from mmdet3d.models.builder import build_head
    # renamed from `self`: using `self` as a plain local variable shadows
    # the method-receiver convention and is confusing to readers
    head = build_head(bbox_head_cfg)
    head.cuda()

    feats = list()
    feats.append(torch.rand([2, 384, 200, 200], dtype=torch.float32).cuda())
    # fake input_metas
    input_metas = [{
        'sample_idx': 1234,
        'box_type_3d': LiDARInstance3DBoxes,
        'box_mode_3d': Box3DMode.LIDAR
    }, {
        'sample_idx': 2345,
        'box_type_3d': LiDARInstance3DBoxes,
        'box_mode_3d': Box3DMode.LIDAR
    }]
    (cls_score, bbox_pred, dir_cls_preds) = head.forward(feats)

    # test get_bboxes
    cls_score[0] -= 1.5  # too many positive samples may cause cuda oom
    result_list = head.get_bboxes(cls_score, bbox_pred, dir_cls_preds,
                                  input_metas)
    assert len(result_list[0][1]) > 0  # ensure not all boxes are filtered
    assert (result_list[0][1] > 0.3).all()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment